diff --git a/image_ref/main.py b/image_ref/main.py
index 5cac8f3d24da4c33bc62f49e898db342ed063b06..001af33fd3057f11f11eeec54377e0d3c6b6976e 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -5,7 +5,7 @@ from config import load_args_contrastive
 from dataset_ref import load_data_duo
 import torch
 import torch.nn as nn
-from model import Classification_model_contrastive, Classification_model_duo_contrastive
+from model import Classification_model_duo_contrastive
 import torch.optim as optim
 from sklearn.metrics import confusion_matrix
 import seaborn as sn
@@ -42,6 +42,7 @@ def test_duo(model, data_test, loss_function, epoch):
     model.eval()
     losses = 0.
     acc = 0.
+    acc_contrastive = 0.
     for param in model.parameters():
         param.requires_grad = False
 
@@ -57,15 +58,18 @@ def test_duo(model, data_test, loss_function, epoch):
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
+        label_class = torch.argmin(label).data.cpu().numpy()
         pred_logits = model.forward(imaer,imana,img_ref)
         pred_class = torch.argmax(pred_logits[:,0])
-        acc += (pred_class==label).sum().item()
+        acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label).sum().item()
+        acc += (pred_class==label_class)
         loss = loss_function(pred_logits,label)
         losses += loss.item()
     losses = losses/(label.shape[0]*len(data_test.dataset))
-    acc = acc/(label.shape[0]*len(data_test.dataset))
-    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
-    return losses,acc
+    acc = acc/(len(data_test.dataset))
+    acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset))
+    print('Test epoch {}, loss : {:.3f}  acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive))
+    return losses,acc,acc_contrastive
 
 def run_duo(args):
     #load data
@@ -97,7 +101,7 @@ def run_duo(args):
         train_loss.append(loss)
         train_acc.append(acc)
         if e%args.eval_inter==0 :
-            loss, acc = test_duo(model,data_test_batch,loss_function,e)
+            loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
             val_loss.append(loss)
             val_acc.append(acc)
             if loss < best_loss :