Skip to content
Snippets Groups Projects
Commit 8837c011 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : class acc val

parent 10228745
No related branches found
No related tags found
No related merge requests found
......@@ -60,7 +60,7 @@ def test_duo(model, data_test, loss_function, epoch):
label = label.cuda()
label_class = torch.argmin(label).data.cpu().numpy()
pred_logits = model.forward(imaer,imana,img_ref)
pred_class = torch.argmax(pred_logits[:,0])
pred_class = torch.argmax(pred_logits[:,0]).tolist()
acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item()
acc += (pred_class==label_class)
loss = loss_function(pred_logits,label)
......@@ -91,6 +91,7 @@ def run_duo(args):
train_acc=[]
train_loss=[]
val_acc=[]
val_cont_acc=[]
val_loss=[]
#init training
loss_function = nn.CrossEntropyLoss()
......@@ -104,14 +105,16 @@ def run_duo(args):
loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
val_loss.append(loss)
val_acc.append(acc)
val_cont_acc.append(acc_contrastive)
if loss < best_loss :
save_model(model,args.save_path)
best_loss = loss
# plot and save training figs
plt.clf()
plt.subplot(2, 1, 1)
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='val')
plt.plot(train_acc, label='train cont acc')
plt.plot(val_cont_acc, label='val cont acc')
plt.plot(val_acc, label='val classification acc')
plt.title('Train and validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment