From 956b97c402257e5768b92d7fbded7d7fda9d9475 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 22 Apr 2025 17:25:25 +0200
Subject: [PATCH] add : wandb random name

---
 image_ref/config.py                      | 2 +-
 image_ref/hyperparameter_res_analysis.py | 2 --
 image_ref/main.py                        | 4 ++--
 image_ref/main_ray.py                    | 1 +
 4 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/image_ref/config.py b/image_ref/config.py
index 95bd541..817e0cd 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -22,7 +22,7 @@ def load_args_contrastive():
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
-    parser.add_argument('--wandb', type=str, default='image_ref/best_model_base_ray.pt')
+    parser.add_argument('--wandb', type=str, default='wandb_run')
     args = parser.parse_args()
 
     return args
\ No newline at end of file
diff --git a/image_ref/hyperparameter_res_analysis.py b/image_ref/hyperparameter_res_analysis.py
index cdabac5..7a411a3 100644
--- a/image_ref/hyperparameter_res_analysis.py
+++ b/image_ref/hyperparameter_res_analysis.py
@@ -2,5 +2,3 @@ import pandas as pd
 import numpy as np
 
 df = pd.read_csv('../df_results_contrastive.csv')
-
-best_param = df[df['val loss']<0.003]
\ No newline at end of file
diff --git a/image_ref/main.py b/image_ref/main.py
index bab6e98..32cba8f 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -101,8 +101,8 @@ def run_duo(args):
         os.environ["WANDB_MODE"] = "offline"
         os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
 
-        wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb)
-
+        wdb.init(project="contrastive_classification", dir='./wandb_run')
+        print('Wandb initialised')
     # load data
     data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                                 base_dir_val=args.dataset_val_dir,
diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 80fa177..801ffc2 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -50,6 +50,7 @@ def train_model(config,args):
         optimizer = optim.Adam(model.parameters(), lr=config["lr"])
     elif config['optimizer']=='SGD' :
         optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
+
     # init training
     n_class = len(data_train.dataset.classes)
     weight = torch.Tensor([1/n_class,1-1/n_class])
-- 
GitLab