From 956b97c402257e5768b92d7fbded7d7fda9d9475 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 22 Apr 2025 17:25:25 +0200 Subject: [PATCH] add : wandb random name --- image_ref/config.py | 2 +- image_ref/hyperparameter_res_analysis.py | 2 -- image_ref/main.py | 4 ++-- image_ref/main_ray.py | 1 + 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index 95bd541..817e0cd 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -22,7 +22,7 @@ def load_args_contrastive(): parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--pretrain_path', type=str, default=None) - parser.add_argument('--wandb', type=str, default='image_ref/best_model_base_ray.pt') + parser.add_argument('--wandb', type=str, default='wandb_run') args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/hyperparameter_res_analysis.py b/image_ref/hyperparameter_res_analysis.py index cdabac5..7a411a3 100644 --- a/image_ref/hyperparameter_res_analysis.py +++ b/image_ref/hyperparameter_res_analysis.py @@ -2,5 +2,3 @@ import pandas as pd import numpy as np df = pd.read_csv('../df_results_contrastive.csv') - -best_param = df[df['val loss']<0.003] \ No newline at end of file diff --git a/image_ref/main.py b/image_ref/main.py index bab6e98..32cba8f 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -101,8 +101,8 @@ def run_duo(args): os.environ["WANDB_MODE"] = "offline" os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run") - wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb) - + wdb.init(project="contrastive_classification", dir='./wandb_run') + print('Wandb initialised') # load data data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 80fa177..801ffc2 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -50,6 +50,7 @@ def train_model(config,args): optimizer = optim.Adam(model.parameters(), lr=config["lr"]) elif config['optimizer']=='SGD' : optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) + # init training n_class = len(data_train.dataset.classes) weight = torch.Tensor([1/n_class,1-1/n_class]) -- GitLab