From 6edf5afae4045d1b61082b5393e750fa6a6a2139 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 13:02:39 +0200
Subject: [PATCH] fix : error device cuda

---
 image_ref/main_ray.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 2cb7285..3530a8b 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -54,6 +54,7 @@ def train_model(config,args):
     n_class = len(data_train.dataset.classes)
     weight = torch.Tensor([1/n_class,1-1/n_class])
     weight.to(device)
+    print('weight',weight.device)
     loss_function = nn.CrossEntropyLoss(weight=weight)
 
     # Load existing checkpoint through `get_checkpoint()` API.
-- 
GitLab