diff --git a/main_ray_tune.py b/main_ray_tune.py
index 50036dc1a665f38b9ae3041365521bff4498ceaa..0ff6221d836adb3e082a082146b04bce94eb7b53 100644
--- a/main_ray_tune.py
+++ b/main_ray_tune.py
@@ -199,7 +199,9 @@ def test_best_model(best_result, args):
     if torch.cuda.is_available():
         device = "cuda:0"
         if torch.cuda.device_count() > 1:
+            print(print(type(best_trained_model.module)))
             best_trained_model = torch.nn.DataParallel(best_trained_model)
+            print(print(type(best_trained_model.module)))
 
     best_trained_model.to(device)
     criterion_rt = torch.nn.MSELoss()