From dfd088dee7e35fcdb9c5f187c9f06f0dfecde7fa Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 10 Mar 2025 15:23:53 +0100
Subject: [PATCH] model cuda loading

---
 main.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/main.py b/main.py
index 64578b7..cd8556e 100644
--- a/main.py
+++ b/main.py
@@ -73,23 +73,23 @@ def run(args):
     loss_function = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
-    # for e in range(args.epoches):
-    #     loss, acc = train(model,data_train,optimizer,loss_function,e)
-    #     train_loss.append(loss)
-    #     train_acc.append(acc)
-    #     if e%args.eval_inter==0 :
-    #         loss, acc = test(model,data_test,loss_function,e)
-    #         val_loss.append(loss)
-    #         val_acc.append(acc)
-    #         if acc > best_acc :
-    #             save_model(model,args.save_path)
-    #             best_acc = acc
-    # plt.plot(train_acc)
-    # plt.plot(val_acc)
-    # plt.plot(train_acc)
-    # plt.plot(train_acc)
-    # plt.show()
-    # plt.savefig('output/training_plot.png')
+    for e in range(args.epoches):
+        loss, acc = train(model,data_train,optimizer,loss_function,e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e%args.eval_inter==0 :
+            loss, acc = test(model,data_test,loss_function,e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc :
+                save_model(model,args.save_path)
+                best_acc = acc
+    plt.plot(train_acc)
+    plt.plot(val_acc)
+    plt.plot(train_acc)
+    plt.plot(train_acc)
+    plt.show()
+    plt.savefig('output/training_plot.png')
 
     load_model(model, args.save_path)
     make_prediction(model,data_test)
-- 
GitLab