From 4efef23b65355269a4f864ac970c19e39a2aad07 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 10 Mar 2025 14:07:52 +0100
Subject: [PATCH] model cuda loading

---
 main.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/main.py b/main.py
index f38d050..2ad0475 100644
--- a/main.py
+++ b/main.py
@@ -57,6 +57,8 @@ def test(model, data_test, loss_function, epoch):
 
 def run(args):
     model = Classification_model(n_class=9)
+    if torch.cuda.is_available():
+        model = model.cuda()
     best_acc = 0
     train_acc=[]
     train_loss=[]
-- 
GitLab