diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 03d901645c50783fe8d51495e3e4b40870b5aaea..d0e236069decfd91945aa9ec9c2e652952febc9a 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -83,9 +83,6 @@ def train_model(config,args): imana = imana.cuda() img_ref = img_ref.cuda() label = label.cuda() - if torch.cuda.device_count() > 1: - pred_logits = model.module.forward(imaer, imana, img_ref) - else: pred_logits = model.forward(imaer, imana, img_ref)