diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 03d901645c50783fe8d51495e3e4b40870b5aaea..d0e236069decfd91945aa9ec9c2e652952febc9a 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -83,9 +83,6 @@ def train_model(config,args):
                 imana = imana.cuda()
                 img_ref = img_ref.cuda()
                 label = label.cuda()
-            if torch.cuda.device_count() > 1:
-                pred_logits = model.module.forward(imaer, imana, img_ref)
-            else:
                 pred_logits = model.forward(imaer, imana, img_ref)