From 4ee57bfec4e9d35cd708125ceea26b8ad1a90bba Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 13:07:03 +0200
Subject: [PATCH] fix : error device cuda

---
 image_ref/main_ray.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 3530a8b..87011bb 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -53,7 +53,8 @@ def train_model(config,args):
     # init training
     n_class = len(data_train.dataset.classes)
     weight = torch.Tensor([1/n_class,1-1/n_class])
-    weight.to(device)
+    if torch.cuda.is_available():
+        weight = weight.cuda()
     print('weight',weight.device)
     loss_function = nn.CrossEntropyLoss(weight=weight)
 
-- 
GitLab