diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 3530a8b46cfb67dd8d0f13d18a449e07218e9a48..87011bbea56e33901b08d5ed90253813d624237d 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -53,7 +53,8 @@ def train_model(config,args):
     # init training
     n_class = len(data_train.dataset.classes)
     weight = torch.Tensor([1/n_class,1-1/n_class])
-    weight.to(device)
+    if torch.cuda.is_available():
+        weight = weight.cuda()
     print('weight',weight.device)
     loss_function = nn.CrossEntropyLoss(weight=weight)