diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index d0e236069decfd91945aa9ec9c2e652952febc9a..2cb7285312d8223b6cc7f8c4826baa1b5aa33193 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -88,6 +88,7 @@ def train_model(config,args): pred_class = torch.argmax(pred_logits, dim=1) acc += (pred_class == label).sum().item() + print(label.device,pred_logits.device) loss = loss_function(pred_logits, label) losses += loss.item() optimizer.zero_grad()