diff --git a/main.py b/main.py index 64578b75a7bfd3f703f1bf368dbde9ffbf6766a6..cd8556e94654b87dc3fbba68d36479e7e4ff12a8 100644 --- a/main.py +++ b/main.py @@ -73,23 +73,23 @@ def run(args): loss_function = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - # for e in range(args.epoches): - # loss, acc = train(model,data_train,optimizer,loss_function,e) - # train_loss.append(loss) - # train_acc.append(acc) - # if e%args.eval_inter==0 : - # loss, acc = test(model,data_test,loss_function,e) - # val_loss.append(loss) - # val_acc.append(acc) - # if acc > best_acc : - # save_model(model,args.save_path) - # best_acc = acc - # plt.plot(train_acc) - # plt.plot(val_acc) - # plt.plot(train_acc) - # plt.plot(train_acc) - # plt.show() - # plt.savefig('output/training_plot.png') + for e in range(args.epoches): + loss, acc = train(model,data_train,optimizer,loss_function,e) + train_loss.append(loss) + train_acc.append(acc) + if e%args.eval_inter==0 : + loss, acc = test(model,data_test,loss_function,e) + val_loss.append(loss) + val_acc.append(acc) + if acc > best_acc : + save_model(model,args.save_path) + best_acc = acc + plt.plot(train_acc) + plt.plot(val_acc) + plt.plot(train_acc) + plt.plot(train_acc) + plt.show() + plt.savefig('output/training_plot.png') load_model(model, args.save_path) make_prediction(model,data_test)