Skip to content
Snippets Groups Projects
Commit fb265a02 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : path

parent 9416aaa5
No related branches found
No related tags found
No related merge requests found
......@@ -84,7 +84,7 @@ def run_duo(args):
model = model.cuda()
#init accumulators
best_acc = 0
best_loss = 100
train_acc=[]
train_loss=[]
val_acc=[]
......@@ -101,18 +101,32 @@ def run_duo(args):
loss, acc = test_duo(model,data_test_batch,loss_function,e)
val_loss.append(loss)
val_acc.append(acc)
if acc > best_acc :
if loss < best_loss :
save_model(model,args.save_path)
best_acc = acc
loss = acc
# plot and save training figs
plt.plot(train_acc)
plt.plot(val_acc)
plt.plot(train_acc)
plt.plot(train_acc)
plt.clf()
plt.subplot(2, 1, 1)
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='val')
plt.title('Train and validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(loc="upper left")
plt.ylim(0, 1.05)
plt.subplot(2, 1, 2)
plt.plot(train_loss, label='train')
plt.plot(val_loss, label='val')
plt.title('Train and validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc="upper left")
plt.show()
plt.savefig('output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
plt.savefig('../output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
#load and evaluate best model
load_model(model, args.save_path)
make_prediction_duo(model,data_test_batch, 'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment