Skip to content
Snippets Groups Projects
Commit 8837c011 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : class acc val

parent 10228745
No related branches found
No related tags found
No related merge requests found
...@@ -60,7 +60,7 @@ def test_duo(model, data_test, loss_function, epoch): ...@@ -60,7 +60,7 @@ def test_duo(model, data_test, loss_function, epoch):
label = label.cuda() label = label.cuda()
label_class = torch.argmin(label).data.cpu().numpy() label_class = torch.argmin(label).data.cpu().numpy()
pred_logits = model.forward(imaer,imana,img_ref) pred_logits = model.forward(imaer,imana,img_ref)
pred_class = torch.argmax(pred_logits[:,0]) pred_class = torch.argmax(pred_logits[:,0]).tolist()
acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item() acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item()
acc += (pred_class==label_class) acc += (pred_class==label_class)
loss = loss_function(pred_logits,label) loss = loss_function(pred_logits,label)
...@@ -91,6 +91,7 @@ def run_duo(args): ...@@ -91,6 +91,7 @@ def run_duo(args):
train_acc=[] train_acc=[]
train_loss=[] train_loss=[]
val_acc=[] val_acc=[]
val_cont_acc=[]
val_loss=[] val_loss=[]
#init training #init training
loss_function = nn.CrossEntropyLoss() loss_function = nn.CrossEntropyLoss()
...@@ -104,14 +105,16 @@ def run_duo(args): ...@@ -104,14 +105,16 @@ def run_duo(args):
loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e) loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
val_loss.append(loss) val_loss.append(loss)
val_acc.append(acc) val_acc.append(acc)
val_cont_acc.append(acc_contrastive)
if loss < best_loss : if loss < best_loss :
save_model(model,args.save_path) save_model(model,args.save_path)
best_loss = loss best_loss = loss
# plot and save training figs # plot and save training figs
plt.clf() plt.clf()
plt.subplot(2, 1, 1) plt.subplot(2, 1, 1)
plt.plot(train_acc, label='train') plt.plot(train_acc, label='train cont acc')
plt.plot(val_acc, label='val') plt.plot(val_cont_acc, label='val cont acc')
plt.plot(val_acc, label='val classification acc')
plt.title('Train and validation accuracy') plt.title('Train and validation accuracy')
plt.xlabel('epoch') plt.xlabel('epoch')
plt.ylabel('accuracy') plt.ylabel('accuracy')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment