From 5eca76523cfc1ecca360b1cab851bd5c44ad100f Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 15 Apr 2025 17:28:41 +0200
Subject: [PATCH] add : weighted val loss

---
 image_ref/main_ray.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index d4eac837..123286fb 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -45,7 +45,8 @@ def train_model(config,args):
     elif config['optimizer']=='SGD' :
         optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
     # init training
-    loss_function = nn.CrossEntropyLoss()
+    n_class = data_train.dataset.classes
+    loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class]))
 
     # Load existing checkpoint through `get_checkpoint()` API.
     if train.get_checkpoint():
@@ -240,7 +241,7 @@ def main(args, gpus_per_trial=1):
 
         ),
         run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
-                             name="base_experiment"
+                             name="weight_val_loss_experiment"
                              ),
         param_space=config
 
-- 
GitLab