From 6edf5afae4045d1b61082b5393e750fa6a6a2139 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 18 Apr 2025 13:02:39 +0200 Subject: [PATCH] fix : error device cuda --- image_ref/main_ray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 2cb7285..3530a8b 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -54,6 +54,7 @@ def train_model(config,args): n_class = len(data_train.dataset.classes) weight = torch.Tensor([1/n_class,1-1/n_class]) weight.to(device) + print('weight',weight.device) loss_function = nn.CrossEntropyLoss(weight=weight) # Load existing checkpoint through `get_checkpoint()` API. -- GitLab