From 4ee57bfec4e9d35cd708125ceea26b8ad1a90bba Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 18 Apr 2025 13:07:03 +0200 Subject: [PATCH] fix : error device cuda --- image_ref/main_ray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 3530a8b..87011bb 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -53,7 +53,8 @@ def train_model(config,args): # init training n_class = len(data_train.dataset.classes) weight = torch.Tensor([1/n_class,1-1/n_class]) - weight.to(device) + if torch.cuda.is_available(): + weight = weight.cuda() print('weight',weight.device) loss_function = nn.CrossEntropyLoss(weight=weight) -- GitLab