Skip to content
Snippets Groups Projects
Commit 4ee57bfe authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : error device cuda

parent 6edf5afa
No related branches found
No related tags found
No related merge requests found
......@@ -53,7 +53,8 @@ def train_model(config,args):
# init training
n_class = len(data_train.dataset.classes)
weight = torch.Tensor([1/n_class,1-1/n_class])
weight.to(device)
if torch.cuda.is_available():
weight = weight.cuda()
print('weight',weight.device)
loss_function = nn.CrossEntropyLoss(weight=weight)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment