diff --git a/image_ref/main.py b/image_ref/main.py index 5cac8f3d24da4c33bc62f49e898db342ed063b06..001af33fd3057f11f11eeec54377e0d3c6b6976e 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -5,7 +5,7 @@ from config import load_args_contrastive from dataset_ref import load_data_duo import torch import torch.nn as nn -from model import Classification_model_contrastive, Classification_model_duo_contrastive +from model import Classification_model_duo_contrastive import torch.optim as optim from sklearn.metrics import confusion_matrix import seaborn as sn @@ -42,6 +42,7 @@ def test_duo(model, data_test, loss_function, epoch): model.eval() losses = 0. acc = 0. + acc_contrastive = 0. for param in model.parameters(): param.requires_grad = False @@ -57,15 +58,18 @@ def test_duo(model, data_test, loss_function, epoch): imana = imana.cuda() img_ref = img_ref.cuda() label = label.cuda() + label_class = torch.argmin(label).data.cpu().numpy() pred_logits = model.forward(imaer,imana,img_ref) pred_class = torch.argmax(pred_logits[:,0]) - acc += (pred_class==label).sum().item() + acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label).sum().item() + acc += (pred_class==label_class) loss = loss_function(pred_logits,label) losses += loss.item() losses = losses/(label.shape[0]*len(data_test.dataset)) - acc = acc/(label.shape[0]*len(data_test.dataset)) - print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc)) - return losses,acc + acc = acc/(len(data_test.dataset)) + acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset)) + print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive)) + return losses,acc,acc_contrastive def run_duo(args): #load data @@ -97,7 +101,7 @@ def run_duo(args): train_loss.append(loss) train_acc.append(acc) if e%args.eval_inter==0 : - loss, acc = test_duo(model,data_test_batch,loss_function,e) + loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e) val_loss.append(loss) val_acc.append(acc) if loss < best_loss :