diff --git a/image_ref/main.py b/image_ref/main.py
index 767cce8e7ce199cc67127807ee3940c006154f86..62321f2899ab255ad6a7e648998888f03ac2c893 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -60,7 +60,7 @@ def test_duo(model, data_test, loss_function, epoch):
             label = label.cuda()
         label_class = torch.argmin(label).data.cpu().numpy()
         pred_logits = model.forward(imaer,imana,img_ref)
-        pred_class = torch.argmax(pred_logits[:,0])
+        pred_class = torch.argmax(pred_logits[:,0]).tolist()
         acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item()
         acc += (pred_class==label_class)
         loss = loss_function(pred_logits,label)
@@ -91,6 +91,7 @@ def run_duo(args):
     train_acc=[]
     train_loss=[]
     val_acc=[]
+    val_cont_acc=[]
     val_loss=[]
     #init training
     loss_function = nn.CrossEntropyLoss()
@@ -104,14 +105,16 @@ def run_duo(args):
             loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
             val_loss.append(loss)
             val_acc.append(acc)
+            val_cont_acc.append(acc_contrastive)
             if loss < best_loss :
                 save_model(model,args.save_path)
                 best_loss = loss
     # plot and save training figs
     plt.clf()
     plt.subplot(2, 1, 1)
-    plt.plot(train_acc, label='train')
-    plt.plot(val_acc, label='val')
+    plt.plot(train_acc, label='train cont acc')
+    plt.plot(val_cont_acc, label='val cont acc')
+    plt.plot(val_acc, label='val classification acc')
     plt.title('Train and validation accuracy')
     plt.xlabel('epoch')
     plt.ylabel('accuracy')