diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 3530a8b46cfb67dd8d0f13d18a449e07218e9a48..87011bbea56e33901b08d5ed90253813d624237d 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -53,7 +53,8 @@ def train_model(config,args): # init training n_class = len(data_train.dataset.classes) weight = torch.Tensor([1/n_class,1-1/n_class]) - weight.to(device) + if torch.cuda.is_available(): + weight = weight.cuda() print('weight',weight.device) loss_function = nn.CrossEntropyLoss(weight=weight)