From 27b63b7247843b9abbb30840914ceaf7487db700 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 12:50:57 +0200
Subject: [PATCH] fix : error device cuda

---
 image_ref/main_ray.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index d0e2360..2cb7285 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -88,6 +88,7 @@ def train_model(config,args):
 
             pred_class = torch.argmax(pred_logits, dim=1)
             acc += (pred_class == label).sum().item()
+            print(label.device,pred_logits.device)
             loss = loss_function(pred_logits, label)
             losses += loss.item()
             optimizer.zero_grad()
-- 
GitLab