diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index d0e236069decfd91945aa9ec9c2e652952febc9a..2cb7285312d8223b6cc7f8c4826baa1b5aa33193 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -88,6 +88,7 @@ def train_model(config,args):
 
             pred_class = torch.argmax(pred_logits, dim=1)
             acc += (pred_class == label).sum().item()
+            print(label.device,pred_logits.device)
             loss = loss_function(pred_logits, label)
             losses += loss.item()
             optimizer.zero_grad()