From 5eca76523cfc1ecca360b1cab851bd5c44ad100f Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 15 Apr 2025 17:28:41 +0200 Subject: [PATCH] add : weighted val loss --- image_ref/main_ray.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index d4eac837..123286fb 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -45,7 +45,8 @@ def train_model(config,args): elif config['optimizer']=='SGD' : optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) # init training - loss_function = nn.CrossEntropyLoss() + n_class = data_train.dataset.classes + loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class])) # Load existing checkpoint through `get_checkpoint()` API. if train.get_checkpoint(): @@ -240,7 +241,7 @@ def main(args, gpus_per_trial=1): ), run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test", - name="base_experiment" + name="weight_val_loss_experiment" ), param_space=config -- GitLab