diff --git a/image_ref/config.py b/image_ref/config.py
index 4f5fb2143771a191183409392a3769b9a166e32d..8b38951a90fb158c810fdc03e0d95c77e17d65c6 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -12,9 +12,11 @@ def load_args_contrastive():
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--positive_prop', type=int, default=30)
     parser.add_argument('--model', type=str, default='ResNet18')
-    parser.add_argument('--sampler', type=str, default=None)
+    parser.add_argument('--sampler', type=str, default=None)  # 'balanced' for weighted oversampling
     parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
     parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
+    parser.add_argument('--dataset_test_dir', type=str, default=None)
+    parser.add_argument('--test_inter', type=int, default=1)  # epochs between test-set evaluations, used in run_duo (assumed default)
     parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index e5dc9871e532a646d4850db905c5148620c3c575..51aa24ee3c0faa4a48664b09b43f5229a790729b 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -154,7 +154,7 @@ class ImageFolderDuo(data.Dataset):
     def __len__(self):
         return len(self.imlist)
 
-def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
+def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir=None, positive_prop=None, sampler=None):
 
 
 
@@ -182,9 +182,14 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
     print('Default val transform')
 
     train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
-    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
+    val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
 
-    if sampler =='weighted' :
+    if base_dir_test is not None:
+        test_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir,
+                                              ref_transform=ref_transform)
+
+    if sampler == 'balanced':
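+        # oversample rare classes: weight each sample by the inverse frequency of its class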
         y_train_label = np.array([i for (_,_,i)in train_dataset.imlist])
         class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)])
         weight = 1. / class_sample_count
@@ -211,7 +216,7 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
             pin_memory=False,
         )
 
-    data_loader_test = data.DataLoader(
+    data_loader_val = data.DataLoader(
         dataset=val_dataset,
         batch_size=1,
         shuffle=shuffle,
@@ -220,7 +225,20 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
         pin_memory=False,
     )
 
-    return data_loader_train, data_loader_test
+    if base_dir_test is not None:
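+        # optional held-out test set: build a loader that mirrors the validation loader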
+        data_loader_test = data.DataLoader(
+            dataset=test_dataset,
+            batch_size=1,
+            shuffle=shuffle,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+        )
+    else:
+        data_loader_test = None
+
+    return data_loader_train, data_loader_val, data_loader_test
 
 
 class ImageFolderDuo_Batched(data.Dataset):
diff --git a/image_ref/main.py b/image_ref/main.py
index 8b261469b2c9d34e5bb29b2d4866c41900c2abca..8e22bdc73bd45aaab2c086962e2304c677667e01 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -13,6 +13,7 @@ from sklearn.metrics import confusion_matrix
 import seaborn as sn
 import pandas as pd
 
+
 def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
     model.train()
     losses = 0.
@@ -20,32 +21,32 @@ def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
     for param in model.parameters():
         param.requires_grad = True
 
-    for imaer,imana, img_ref, label in data_train:
+    for imaer, imana, img_ref, label in data_train:
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
-        pred_logits = model.forward(imaer,imana,img_ref)
-        pred_class = torch.argmax(pred_logits,dim=1)
-        acc += (pred_class==label).sum().item()
-        loss = loss_function(pred_logits,label)
+        pred_logits = model.forward(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
         losses += loss.item()
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-    losses = losses/len(data_train.dataset)
-    acc = acc/len(data_train.dataset)
-    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
 
     if wandb is not None:
-        wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc })
-
+        wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})
 
     return losses, acc
 
-def test_duo(model, data_test, loss_function, epoch, wandb):
+
+def val_duo(model, data_test, loss_function, epoch, wandb):
     model.eval()
     losses = 0.
     acc = 0.
@@ -53,11 +54,12 @@ def test_duo(model, data_test, loss_function, epoch, wandb):
     for param in model.parameters():
         param.requires_grad = False
 
-    for imaer,imana, img_ref, label in data_test:
-        imaer = imaer.transpose(0,1)
-        imana = imana.transpose(0,1)
-        img_ref = img_ref.transpose(0,1)
-        label = label.transpose(0,1)
+    for imaer, imana, img_ref, label in data_test:
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
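+        # batch_size is 1 here; after the transpose, dim 0 enumerates the candidate reference pairs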
         label = label.squeeze()
         label = label.long()
         if torch.cuda.is_available():
@@ -66,75 +68,85 @@ def test_duo(model, data_test, loss_function, epoch, wandb):
             img_ref = img_ref.cuda()
             label = label.cuda()
         label_class = torch.argmin(label).data.cpu().numpy()
-        pred_logits = model.forward(imaer,imana,img_ref)
-        pred_class = torch.argmax(pred_logits[:,0]).tolist()
-        acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item()
-        acc += (pred_class==label_class)
-        loss = loss_function(pred_logits,label)
+        pred_logits = model.forward(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+        acc_contrastive += (
+                    torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
+        acc += (pred_class == label_class)
+        loss = loss_function(pred_logits, label)
         losses += loss.item()
-    losses = losses/(label.shape[0]*len(data_test.dataset))
-    acc = acc/(len(data_test.dataset))
-    acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset))
-    print('Test epoch {}, loss : {:.3f}  acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive))
+    losses = losses / (label.shape[0] * len(data_test.dataset))
+    acc = acc / (len(data_test.dataset))
+    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
+    print('Eval epoch {}, loss : {:.3f}  acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
+                                                                                       acc_contrastive))
 
     if wandb is not None:
-        wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc, "validation contrastive accuracy" : acc_contrastive })
+        wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
+                 "validation contrastive accuracy": acc_contrastive})
 
-    return losses,acc,acc_contrastive
+    return losses, acc, acc_contrastive
 
-def run_duo(args):
 
-    #wandb init
+def run_duo(args):
+    # wandb init
     if args.wandb is not None:
         os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
 
         os.environ["WANDB_MODE"] = "offline"
         os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
 
-        wdb.init(project="Intensity prediction", dir='./wandb_run', name=args.wandb)
+        wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb)
 
-    #load data
-    data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size,
-                                          ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler)
+    # load data
+    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                                                base_dir_val=args.dataset_val_dir,
+                                                                base_dir_test=args.dataset_test_dir,
+                                                                batch_size=args.batch_size,
+                                                                ref_dir=args.dataset_ref_dir,
+                                                                positive_prop=args.positive_prop, sampler=args.sampler)
 
-    #load model
-    model = Classification_model_duo_contrastive(model = args.model, n_class=2)
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
     model.double()
-    #load weight
-    if args.pretrain_path is not None :
+    # load weight
+    if args.pretrain_path is not None:
         print('Model weight loaded')
-        load_model(model,args.pretrain_path)
-    #move parameters to GPU
+        load_model(model, args.pretrain_path)
+    # move parameters to GPU
     if torch.cuda.is_available():
         print('Model loaded on GPU')
         model = model.cuda()
 
-    #init accumulators
+    # init accumulators
     best_loss = 100
-    train_acc=[]
-    train_loss=[]
-    val_acc=[]
-    val_cont_acc=[]
-    val_loss=[]
-    #init training
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_cont_acc = []
+    val_loss = []
+    # init training
     loss_function = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-    #train model
+    # train model
     for e in range(args.epoches):
-        loss, acc = train_duo(model,data_train,optimizer,loss_function,e,args.wandb)
+        loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
         train_loss.append(loss)
         train_acc.append(acc)
-        if e%args.eval_inter==0 :
-            loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e,args.wandb)
+        if e % args.eval_inter == 0:
+            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb)
             val_loss.append(loss)
             val_acc.append(acc)
             val_cont_acc.append(acc_contrastive)
-            if loss < best_loss :
-                save_model(model,args.save_path)
+            if loss < best_loss:
+                save_model(model, args.save_path)
                 best_loss = loss
+        if e % args.test_inter == 0 and args.dataset_test_dir is not None:
+            # periodic test-set evaluation; test metrics are not appended to the validation curves
+            val_duo(model, data_test_batch, loss_function, e, args.wandb)
     # plot and save training figs
     if args.wandb is None:
-
         plt.clf()
         plt.subplot(2, 1, 1)
         plt.plot(train_acc, label='train cont acc')
@@ -159,10 +171,16 @@ def run_duo(args):
         plt.show()
         plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop))
 
-    #load and evaluate best model
+    # load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test_batch, 'output/confusion_matrix_contractive_{}_bis.png'.format(args.positive_prop),
-                        'output/confidence_matrix_contractive_{}_bis.png'.format(args.positive_prop))
+    if args.dataset_test_dir is not None:
+        make_prediction_duo(model, data_test_batch,
+                            'output/confusion_matrix_contrastive_{}_bis_test.png'.format(args.positive_prop),
+                            'output/confidence_matrix_contrastive_{}_bis_test.png'.format(args.positive_prop))
+
+    make_prediction_duo(model, data_val_batch,
+                        'output/confusion_matrix_contrastive_{}_bis_val.png'.format(args.positive_prop),
+                        'output/confidence_matrix_contrastive_{}_bis_val.png'.format(args.positive_prop))
 
     if args.wandb is not None:
         wdb.finish()
@@ -177,26 +195,26 @@ def make_prediction_duo(model, data, f_name, f_name2):
     y_true = []
     soft_max = nn.Softmax(dim=1)
     # iterate over test data
-    for imaer,imana,img_ref, label in data:
-        imaer = imaer.transpose(0,1)
-        imana = imana.transpose(0,1)
-        img_ref = img_ref.transpose(0,1)
-        label = label.transpose(0,1)
+    for imaer, imana, img_ref, label in data:
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
         label = label.squeeze()
         label = label.long()
         specie = torch.argmin(label)
 
-
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
-        output = model(imaer,imana,img_ref)
+        output = model(imaer, imana, img_ref)
         confidence = soft_max(output)
-        confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy())
-        #Mono class output (only most postive paire)
-        output = torch.argmax(output[:,0])
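+        # group each pair's positive-class confidence under the sample's true species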
+        confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy())
+        # mono-class output (index of the most positive pair)
+        output = torch.argmax(output[:, 0])
         label = torch.argmin(label)
         y_pred.append(output.tolist())
         y_true.append(label.tolist())  # Save Truth
@@ -205,9 +223,10 @@ def make_prediction_duo(model, data, f_name, f_name2):
     # Build confusion matrix
     classes = data.dataset.classes
     cf_matrix = confusion_matrix(y_true, y_pred)
-    confidence_matrix = np.zeros((n_class,n_class))
+    confidence_matrix = np.zeros((n_class, n_class))
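+    # row i holds the mean confidence vector of samples whose true class is i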
     for i in range(n_class):
-        confidence_matrix[i]=np.mean(confidence_pred_list[i],axis=0)
+        confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0)
 
     df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
                          columns=[i for i in classes])
@@ -230,12 +249,12 @@ def save_model(model, path):
     print('Model saved')
     torch.save(model.state_dict(), path)
 
+
 def load_model(model, path):
     model.load_state_dict(torch.load(path, weights_only=True))
 
 
-
 if __name__ == '__main__':
     args = load_args_contrastive()
     print(args)
-    run_duo(args)
\ No newline at end of file
+    run_duo(args)