diff --git a/config/config.py b/config/config.py
index 4975373d77200ddcd1fee00e9d2b19b59ebd6193..902bbbc1ed3242e377c39d53f81ebedc0fc39d4a 100644
--- a/config/config.py
+++ b/config/config.py
@@ -32,7 +32,7 @@ def load_args_contrastive():
     parser.add_argument('--model', type=str, default='ResNet18')
     parser.add_argument('--model_type', type=str, default='duo')
     parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training')
-    parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
+    parser.add_argument('--dataset_ref_dir', type=str, default='img_ref')
     parser.add_argument('--output', type=str, default='../output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
diff --git a/dataset/dataset_ref.py b/dataset/dataset_ref.py
index b77f9c7e09e7e30edd292f3a6afab1e97ddbe46a..45fd094d6321fe1a935149c8846113b4125f2f67 100644
--- a/dataset/dataset_ref.py
+++ b/dataset/dataset_ref.py
@@ -127,13 +127,15 @@ class ImageFolderDuo(data.Dataset):
         imgANA = self.loader(impathANA)
         label_ref = np.random.randint(0,len(self.classes)-1)
         class_ref = self.classes[label_ref]
-        path_ref = self.ref_dir + class_ref +'.npy'
+        path_ref = self.ref_dir +'/'+ class_ref + '.npy'
+        img_ref = self.loader(path_ref)
         if self.transform is not None:
             imgAER = self.transform(imgAER)
             imgANA = self.transform(imgANA)
+            img_ref = self.transform(img_ref)
         if self.target_transform is not None:
             target = 0 if self.target_transform(target) == label_ref else 1
-        img_ref = self.loader(path_ref)
+
         return imgAER, imgANA, img_ref, target
 
     def __len__(self):
diff --git a/image_ref/img_ref/Citrobacter freundii.npy b/image_ref/img_ref/Citrobacter freundii.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9edcae93f12cb1afcc253a6ee7ea0ca106d9c368
Binary files /dev/null and b/image_ref/img_ref/Citrobacter freundii.npy differ
diff --git a/image_ref/img_ref/Citrobacter freundii.png b/image_ref/img_ref/Citrobacter freundii.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb0c83b2c79f37c029b02ac7ea3ff3515e8f7729
Binary files /dev/null and b/image_ref/img_ref/Citrobacter freundii.png differ
diff --git a/image_ref/img_ref/Enterobacter hormaechei.npy b/image_ref/img_ref/Enterobacter hormaechei.npy
new file mode 100644
index 0000000000000000000000000000000000000000..633057f00780be8852cda8d4b52782b5c51377b0
Binary files /dev/null and b/image_ref/img_ref/Enterobacter hormaechei.npy differ
diff --git a/image_ref/img_ref/Enterobacter hormaechei.png b/image_ref/img_ref/Enterobacter hormaechei.png
new file mode 100644
index 0000000000000000000000000000000000000000..67de2dc3a2022ca1dd5ac38b257a12721171a0cb
Binary files /dev/null and b/image_ref/img_ref/Enterobacter hormaechei.png differ
diff --git a/image_ref/img_ref/Klebsiella oxytoca.npy b/image_ref/img_ref/Klebsiella oxytoca.npy
new file mode 100644
index 0000000000000000000000000000000000000000..c43f586d30716e0d4eedd52ab0e4be34459635c2
Binary files /dev/null and b/image_ref/img_ref/Klebsiella oxytoca.npy differ
diff --git a/image_ref/img_ref/Klebsiella oxytoca.png b/image_ref/img_ref/Klebsiella oxytoca.png
new file mode 100644
index 0000000000000000000000000000000000000000..c38c76634ef58a36a3c75241dace225bf220ba6b
Binary files /dev/null and b/image_ref/img_ref/Klebsiella oxytoca.png differ
diff --git a/image_ref/img_ref/Klebsiella pneumoniae.npy b/image_ref/img_ref/Klebsiella pneumoniae.npy
new file mode 100644
index 0000000000000000000000000000000000000000..701ad33c6557d615d252a6dfcf5f233743e9780c
Binary files /dev/null and b/image_ref/img_ref/Klebsiella pneumoniae.npy differ
diff --git a/image_ref/img_ref/Klebsiella pneumoniae.png b/image_ref/img_ref/Klebsiella pneumoniae.png
new file mode 100644
index 0000000000000000000000000000000000000000..3933455b5802a3c3bc76b0f039e5f7a3fe4ffd83
Binary files /dev/null and b/image_ref/img_ref/Klebsiella pneumoniae.png differ
diff --git a/image_ref/img_ref/Proteus mirabilis.npy b/image_ref/img_ref/Proteus mirabilis.npy
new file mode 100644
index 0000000000000000000000000000000000000000..35a5f5392fa0be0dbb52fa061bd1f1116a1a4f40
Binary files /dev/null and b/image_ref/img_ref/Proteus mirabilis.npy differ
diff --git a/image_ref/img_ref/Proteus mirabilis.png b/image_ref/img_ref/Proteus mirabilis.png
new file mode 100644
index 0000000000000000000000000000000000000000..c0276c7d4082bd653e79525861022994c20960bf
Binary files /dev/null and b/image_ref/img_ref/Proteus mirabilis.png differ
diff --git a/image_ref/main.py b/image_ref/main.py
index b659004dffa105a4abcc92af90722fbadcb67362..f4590a29e0fdf414869185ed7775934b4138378c 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -6,7 +6,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from config.config import load_args
+from config.config import load_args, load_args_contrastive
 from dataset.dataset_ref import load_data, load_data_duo
 import torch
 import torch.nn as nn
@@ -146,13 +146,14 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
     for param in model.parameters():
         param.requires_grad = True
 
-    for imaer,imana, label in data_train:
+    for imaer,imana, img_ref, label in data_train:
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
+            img_ref = img_ref.cuda()
             label = label.cuda()
-        pred_logits = model.forward(imaer,imana)
+        pred_logits = model.forward(imaer,imana,img_ref)
         pred_class = torch.argmax(pred_logits,dim=1)
         acc += (pred_class==label).sum().item()
         loss = loss_function(pred_logits,label)
@@ -172,13 +173,14 @@ def test_duo(model, data_test, loss_function, epoch):
     for param in model.parameters():
         param.requires_grad = False
 
-    for imaer,imana, label in data_test:
+    for imaer,imana, img_ref, label in data_test:
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
+            img_ref = img_ref.cuda()
             label = label.cuda()
-        pred_logits = model.forward(imaer,imana)
+        pred_logits = model.forward(imaer,imana,img_ref)
         pred_class = torch.argmax(pred_logits,dim=1)
         acc += (pred_class==label).sum().item()
         loss = loss_function(pred_logits,label)
@@ -190,7 +192,7 @@ def test_duo(model, data_test, loss_function, epoch):
 
 def run_duo(args):
     #load data
-    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, )
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ref_dir=args.dataset_ref_dir)
     #load model
     model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes))
     model.double()
@@ -240,13 +242,14 @@ def make_prediction_duo(model, data, f_name):
     y_pred = []
     y_true = []
     # iterate over test data
-    for imaer,imana, label in data:
+    for imaer,imana,img_ref, label in data:
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
+            img_ref = img_ref.cuda()
             label = label.cuda()
-        output = model(imaer,imana)
+        output = model(imaer,imana,img_ref)
 
         output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
         y_pred.extend(output)
@@ -277,7 +280,7 @@ def load_model(model, path):
 
 
 if __name__ == '__main__':
-    args = load_args()
+    args = load_args_contrastive()
     if args.model_type=='duo':
         run_duo(args)
     else :
diff --git a/image_ref/model.py b/image_ref/model.py
index 0374d1e73f5e94567ab34aec52130b90e9b710dc..1f956957ee41df11683918195c64312f5db2938d 100644
--- a/image_ref/model.py
+++ b/image_ref/model.py
@@ -286,8 +286,8 @@ class Classification_model_duo_contrastive(nn.Module):
 
 
     def forward(self, input_aer, input_ana, input_ref):
-        input_ana = torch.concat(input_ana, input_ref, dim=2)
-        input_aer = torch.concat(input_aer, input_ref, dim=2)
+        input_ana = torch.concat([input_ana, input_ref], dim=1)
+        input_aer = torch.concat([input_aer, input_ref], dim=1)
         out_aer =  self.im_encoder(input_aer)
         out_ana = self.im_encoder(input_ana)
         out = torch.concat([out_aer,out_ana],dim=1)