From e68172ea30c5ae50457e722785b29d205efd6603 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 15 Apr 2025 13:57:05 +0200 Subject: [PATCH] fix : config --- image_ref/main_ray.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 8ebb066..81564a8 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -24,7 +24,7 @@ def train_model(config,args): batch_size=args.batch_size, ref_dir=args.dataset_ref_dir, noise_threshold=config['noise'], - positive_prop=config['positive_prop'], sampler=config['sampler']) + positive_prop=config['p_prop'], sampler=config['sampler']) # load model model = Classification_model_duo_contrastive(model=args.model, n_class=2) @@ -142,7 +142,7 @@ def train_model(config,args): checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) print(checkpoint.path) train.report( - {"train loss": losses_train, "train contrastive acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val contrastive acc": acc_contrastive_val,}, + {"train loss": losses_train, "train cont acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val cont acc": acc_contrastive_val,}, checkpoint=checkpoint,) print("Finished Training") @@ -155,7 +155,7 @@ def test_model(best_result, args): batch_size=args.batch_size, ref_dir=args.dataset_ref_dir, noise_threshold=best_result.config['noise'], - positive_prop=best_result.config['positive_prop'], sampler=best_result.config['sampler']) + positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler']) # load model model = Classification_model_duo_contrastive(model=args.model, n_class=2) @@ -213,13 +213,13 @@ def main(args, gpus_per_trial=1): config = { "lr": tune.loguniform(1e-4, 1e-2), "noise": tune.loguniform(1, 1000), - "positive_prop": tune.uniform(0, 100), + "p_prop": tune.uniform(5, 95), "optimizer": tune.choice(['Adam', 'SGD']), "sampler": tune.choice(['random', 'balanced']), } scheduler = ASHAScheduler( max_t=100, - grace_period=20, + grace_period=3, reduction_factor=3, brackets=1, ) -- GitLab