diff --git a/image_ref/main.py b/image_ref/main.py index 767cce8e7ce199cc67127807ee3940c006154f86..62321f2899ab255ad6a7e648998888f03ac2c893 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -60,7 +60,7 @@ def test_duo(model, data_test, loss_function, epoch): label = label.cuda() label_class = torch.argmin(label).data.cpu().numpy() pred_logits = model.forward(imaer,imana,img_ref) - pred_class = torch.argmax(pred_logits[:,0]) + pred_class = torch.argmax(pred_logits[:,0]).tolist() acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item() acc += (pred_class==label_class) loss = loss_function(pred_logits,label) @@ -91,6 +91,7 @@ def run_duo(args): train_acc=[] train_loss=[] val_acc=[] + val_cont_acc=[] val_loss=[] #init training loss_function = nn.CrossEntropyLoss() @@ -104,14 +105,16 @@ def run_duo(args): loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e) val_loss.append(loss) val_acc.append(acc) + val_cont_acc.append(acc_contrastive) if loss < best_loss : save_model(model,args.save_path) best_loss = loss # plot and save training figs plt.clf() plt.subplot(2, 1, 1) - plt.plot(train_acc, label='train') - plt.plot(val_acc, label='val') + plt.plot(train_acc, label='train cont acc') + plt.plot(val_cont_acc, label='val cont acc') + plt.plot(val_acc, label='val classification acc') plt.title('Train and validation accuracy') plt.xlabel('epoch') plt.ylabel('accuracy')