Skip to content
Snippets Groups Projects
Commit 516d4532 authored by Schneider Leo's avatar Schneider Leo
Browse files

remove : debugging print

parent 3b64549d
No related branches found
No related tags found
No related merge requests found
...@@ -56,7 +56,6 @@ def train_model(config,args): ...@@ -56,7 +56,6 @@ def train_model(config,args):
weight.float() weight.float()
if torch.cuda.is_available(): if torch.cuda.is_available():
weight = weight.cuda() weight = weight.cuda()
print('weight',weight.device)
loss_function = nn.CrossEntropyLoss(weight=weight) loss_function = nn.CrossEntropyLoss(weight=weight)
# Load existing checkpoint through `get_checkpoint()` API. # Load existing checkpoint through `get_checkpoint()` API.
...@@ -94,7 +93,6 @@ def train_model(config,args): ...@@ -94,7 +93,6 @@ def train_model(config,args):
pred_class = torch.argmax(pred_logits, dim=1) pred_class = torch.argmax(pred_logits, dim=1)
acc += (pred_class == label).sum().item() acc += (pred_class == label).sum().item()
print(label.device,pred_logits.device)
loss = loss_function(pred_logits, label) loss = loss_function(pred_logits, label)
losses += loss.item() losses += loss.item()
optimizer.zero_grad() optimizer.zero_grad()
...@@ -195,9 +193,7 @@ def test_model(best_result, args): ...@@ -195,9 +193,7 @@ def test_model(best_result, args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda:0" device = "cuda:0"
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
print(type(model))
net = torch.nn.DataParallel(model) net = torch.nn.DataParallel(model)
print(type(net))
model.to(device) model.to(device)
# init training # init training
loss_function = nn.CrossEntropyLoss() loss_function = nn.CrossEntropyLoss()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment