From 27b63b7247843b9abbb30840914ceaf7487db700 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 18 Apr 2025 12:50:57 +0200 Subject: [PATCH] fix : error device cuda --- image_ref/main_ray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index d0e2360..2cb7285 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -88,6 +88,7 @@ def train_model(config,args): pred_class = torch.argmax(pred_logits, dim=1) acc += (pred_class == label).sum().item() + print(label.device,pred_logits.device) loss = loss_function(pred_logits, label) losses += loss.item() optimizer.zero_grad() -- GitLab