Skip to content
Snippets Groups Projects
Commit 0ebf7166 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : add contrastive accuracy during test

parent 5a439f53
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@ from config import load_args_contrastive
from dataset_ref import load_data_duo
import torch
import torch.nn as nn
from model import Classification_model_contrastive, Classification_model_duo_contrastive
from model import Classification_model_duo_contrastive
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
......@@ -42,6 +42,7 @@ def test_duo(model, data_test, loss_function, epoch):
model.eval()
losses = 0.
acc = 0.
acc_contrastive = 0.
for param in model.parameters():
param.requires_grad = False
......@@ -57,15 +58,18 @@ def test_duo(model, data_test, loss_function, epoch):
imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda()
label_class = torch.argmin(label).data.cpu().numpy()
pred_logits = model.forward(imaer,imana,img_ref)
pred_class = torch.argmax(pred_logits[:,0])
acc += (pred_class==label).sum().item()
acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label).sum().item()
acc += (pred_class==label_class)
loss = loss_function(pred_logits,label)
losses += loss.item()
losses = losses/(label.shape[0]*len(data_test.dataset))
acc = acc/(label.shape[0]*len(data_test.dataset))
print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses,acc
acc = acc/(len(data_test.dataset))
acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset))
print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive))
return losses,acc,acc_contrastive
def run_duo(args):
#load data
......@@ -97,7 +101,7 @@ def run_duo(args):
train_loss.append(loss)
train_acc.append(acc)
if e%args.eval_inter==0 :
loss, acc = test_duo(model,data_test_batch,loss_function,e)
loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
val_loss.append(loss)
val_acc.append(acc)
if loss < best_loss :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment