Skip to content
Snippets Groups Projects
Commit 5eca7652 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : weighted val loss

parent e585f394
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,8 @@ def train_model(config,args):
elif config['optimizer']=='SGD' :
optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
# init training
loss_function = nn.CrossEntropyLoss()
n_class = data_train.dataset.classes
loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class]))
# Load existing checkpoint through `get_checkpoint()` API.
if train.get_checkpoint():
......@@ -240,7 +241,7 @@ def main(args, gpus_per_trial=1):
),
run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
name="base_experiment"
name="weight_val_loss_experiment"
),
param_space=config
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment