From 9d1bc5abde614d91653f5114ff013a9a728278aa Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 10 Mar 2025 15:12:51 +0100 Subject: [PATCH] model cuda loading --- main.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/main.py b/main.py index cd8556e..64578b7 100644 --- a/main.py +++ b/main.py @@ -73,23 +73,23 @@ def run(args): loss_function = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - for e in range(args.epoches): - loss, acc = train(model,data_train,optimizer,loss_function,e) - train_loss.append(loss) - train_acc.append(acc) - if e%args.eval_inter==0 : - loss, acc = test(model,data_test,loss_function,e) - val_loss.append(loss) - val_acc.append(acc) - if acc > best_acc : - save_model(model,args.save_path) - best_acc = acc - plt.plot(train_acc) - plt.plot(val_acc) - plt.plot(train_acc) - plt.plot(train_acc) - plt.show() - plt.savefig('output/training_plot.png') + # for e in range(args.epoches): + # loss, acc = train(model,data_train,optimizer,loss_function,e) + # train_loss.append(loss) + # train_acc.append(acc) + # if e%args.eval_inter==0 : + # loss, acc = test(model,data_test,loss_function,e) + # val_loss.append(loss) + # val_acc.append(acc) + # if acc > best_acc : + # save_model(model,args.save_path) + # best_acc = acc + # plt.plot(train_acc) + # plt.plot(val_acc) + # plt.plot(train_acc) + # plt.plot(train_acc) + # plt.show() + # plt.savefig('output/training_plot.png') load_model(model, args.save_path) make_prediction(model,data_test) -- GitLab