Skip to content
Snippets Groups Projects
Commit 6edf5afa authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : error device cuda

parent 27b63b72
No related branches found
No related tags found
No related merge requests found
...@@ -54,6 +54,7 @@ def train_model(config,args): ...@@ -54,6 +54,7 @@ def train_model(config,args):
n_class = len(data_train.dataset.classes) n_class = len(data_train.dataset.classes)
weight = torch.Tensor([1/n_class,1-1/n_class]) weight = torch.Tensor([1/n_class,1-1/n_class])
weight.to(device) weight.to(device)
print('weight',weight.device)
loss_function = nn.CrossEntropyLoss(weight=weight) loss_function = nn.CrossEntropyLoss(weight=weight)
# Load existing checkpoint through `get_checkpoint()` API. # Load existing checkpoint through `get_checkpoint()` API.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment