diff --git a/main.py b/main.py index cd8556e94654b87dc3fbba68d36479e7e4ff12a8..64578b75a7bfd3f703f1bf368dbde9ffbf6766a6 100644 --- a/main.py +++ b/main.py @@ -73,23 +73,23 @@ def run(args): loss_function = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - for e in range(args.epoches): - loss, acc = train(model,data_train,optimizer,loss_function,e) - train_loss.append(loss) - train_acc.append(acc) - if e%args.eval_inter==0 : - loss, acc = test(model,data_test,loss_function,e) - val_loss.append(loss) - val_acc.append(acc) - if acc > best_acc : - save_model(model,args.save_path) - best_acc = acc - plt.plot(train_acc) - plt.plot(val_acc) - plt.plot(train_acc) - plt.plot(train_acc) - plt.show() - plt.savefig('output/training_plot.png') + # for e in range(args.epoches): + # loss, acc = train(model,data_train,optimizer,loss_function,e) + # train_loss.append(loss) + # train_acc.append(acc) + # if e%args.eval_inter==0 : + # loss, acc = test(model,data_test,loss_function,e) + # val_loss.append(loss) + # val_acc.append(acc) + # if acc > best_acc : + # save_model(model,args.save_path) + # best_acc = acc + # plt.plot(train_acc) + # plt.plot(val_acc) + # plt.plot(train_acc) + # plt.plot(train_acc) + # plt.show() + # plt.savefig('output/training_plot.png') load_model(model, args.save_path) make_prediction(model,data_test)