diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index bd3aa6a4cdaac12564cf65891d533859f2620e82..b06aa9bf2f1cda13305b876ab8c8f58787f1b8cf 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -18,11 +18,19 @@ from ray.tune.schedulers import ASHAScheduler def train_model(config,args): # load data + + if config['res_count_thr']=='none': + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref' + elif config['res_count_thr']=='10': + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_10' + else : + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_5' + data_train, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, batch_size=args.batch_size, - ref_dir=args.dataset_ref_dir, + ref_dir=ref_dir, noise_threshold=config['noise'], positive_prop=config['p_prop'], sampler=config['sampler']) @@ -149,12 +157,20 @@ def train_model(config,args): def test_model(best_result, args): + + if best_result.config['res_count_thr']=='none': + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref' + elif best_result.config['res_count_thr']=='10': + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_10' + else : + ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_5' + # load data _, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, batch_size=args.batch_size, - ref_dir=args.dataset_ref_dir, + ref_dir=ref_dir, noise_threshold=best_result.config['noise'], positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler']) @@ -217,6 +233,7 @@ def main(args, gpus_per_trial=1): "p_prop": tune.uniform(5, 95), "optimizer": tune.choice(['Adam', 'SGD']), #adam plus efficace ? "sampler": tune.choice(['random', 'balanced']), + "ref_count_thr" : tune.choice(['none', '10', '5']) } scheduler = ASHAScheduler( max_t=100, @@ -235,13 +252,13 @@ def main(args, gpus_per_trial=1): time_budget_s=3600 * 19.5, search_alg=algo, scheduler=scheduler, - num_samples=50, + num_samples=-1, metric="val loss", mode='min', ), run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test", - name="weight_val_loss_experiment" + name="ref_count_threshold_experiment" ), param_space=config