Skip to content
Snippets Groups Projects
Commit 27b63b72 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : error device cuda

parent ef606767
No related branches found
No related tags found
No related merge requests found
......@@ -88,6 +88,7 @@ def train_model(config,args):
pred_class = torch.argmax(pred_logits, dim=1)
acc += (pred_class == label).sum().item()
print(label.device,pred_logits.device)
loss = loss_function(pred_logits, label)
losses += loss.item()
optimizer.zero_grad()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment