diff --git a/main_ray.py b/main_ray.py
index 86cf0874a426c4f5728ca862307c8c2b95572e1d..a45202e864f59f5f7ce4ab4184fef1f4f6b40f45 100644
--- a/main_ray.py
+++ b/main_ray.py
@@ -29,7 +29,7 @@ def train_model(config,args):
     model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.classes))
 
     # move parameters to GPU
-    model.double()
+    model.float()
     device = "cpu"
     if torch.cuda.is_available():
         device = "cuda:0"
@@ -45,7 +45,8 @@ def train_model(config,args):
     elif config['loss']== 'weighed':
         classes_numbers = [51, 12, 9, 10, 86, 231, 20, 13, 24, 96, 11, 39, 11]
         loss_weights = torch.tensor([1/n for n in classes_numbers])
-        loss_weights.to(device)
+        if torch.cuda.is_available():
+            loss_weights = loss_weights.cuda()
         loss_function = nn.CrossEntropyLoss(loss_weights)
     # Load existing checkpoint through `get_checkpoint()` API.
     if train.get_checkpoint():
@@ -69,6 +70,8 @@ def train_model(config,args):
 
         for imaer, imana, label in data_train:
             label = label.long()
+            imaer = imaer.float()
+            imana = imana.float()
             if torch.cuda.is_available():
                 imaer = imaer.cuda()
                 imana = imana.cuda()
@@ -92,6 +95,8 @@ def train_model(config,args):
             param.requires_grad = False
 
         for imaer, imana, label in data_test:
+            imaer = imaer.float()
+            imana = imana.float()
             label = label.long()
             if torch.cuda.is_available():
                 imaer = imaer.cuda()
@@ -140,7 +145,7 @@ def test_model(best_result, args):
 
     # load model
     model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.classes))
-    model.double()
+    model.float()
     # load weight
     checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
 
@@ -165,6 +170,8 @@ def test_model(best_result, args):
         param.requires_grad = False
 
     for imaer, imana, label in data_test:
+        imaer = imaer.float()
+        imana = imana.float()
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()