Skip to content
Snippets Groups Projects
Commit 06a7c68e authored by Schneider Leo's avatar Schneider Leo
Browse files

change : ray conf

parent cdf30a65
No related branches found
No related tags found
No related merge requests found
...@@ -18,11 +18,19 @@ from ray.tune.schedulers import ASHAScheduler ...@@ -18,11 +18,19 @@ from ray.tune.schedulers import ASHAScheduler
def train_model(config,args): def train_model(config,args):
# load data # load data
if config['res_count_thr']=='none':
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref'
elif config['res_count_thr']=='10':
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_10'
else :
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_5'
data_train, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir, data_train, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
base_dir_val=args.dataset_val_dir, base_dir_val=args.dataset_val_dir,
base_dir_test=args.dataset_test_dir, base_dir_test=args.dataset_test_dir,
batch_size=args.batch_size, batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir, ref_dir=ref_dir,
noise_threshold=config['noise'], noise_threshold=config['noise'],
positive_prop=config['p_prop'], sampler=config['sampler']) positive_prop=config['p_prop'], sampler=config['sampler'])
...@@ -149,12 +157,20 @@ def train_model(config,args): ...@@ -149,12 +157,20 @@ def train_model(config,args):
def test_model(best_result, args): def test_model(best_result, args):
if best_result.config['res_count_thr']=='none':
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref'
elif best_result.config['res_count_thr']=='10':
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_10'
else :
ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_5'
# load data # load data
_, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir, _, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
base_dir_val=args.dataset_val_dir, base_dir_val=args.dataset_val_dir,
base_dir_test=args.dataset_test_dir, base_dir_test=args.dataset_test_dir,
batch_size=args.batch_size, batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir, ref_dir=ref_dir,
noise_threshold=best_result.config['noise'], noise_threshold=best_result.config['noise'],
positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler']) positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler'])
...@@ -217,6 +233,7 @@ def main(args, gpus_per_trial=1): ...@@ -217,6 +233,7 @@ def main(args, gpus_per_trial=1):
"p_prop": tune.uniform(5, 95), "p_prop": tune.uniform(5, 95),
"optimizer": tune.choice(['Adam', 'SGD']), #adam plus efficace ? "optimizer": tune.choice(['Adam', 'SGD']), #adam plus efficace ?
"sampler": tune.choice(['random', 'balanced']), "sampler": tune.choice(['random', 'balanced']),
"ref_count_thr" : tune.choice(['none', '10', '5'])
} }
scheduler = ASHAScheduler( scheduler = ASHAScheduler(
max_t=100, max_t=100,
...@@ -235,13 +252,13 @@ def main(args, gpus_per_trial=1): ...@@ -235,13 +252,13 @@ def main(args, gpus_per_trial=1):
time_budget_s=3600 * 19.5, time_budget_s=3600 * 19.5,
search_alg=algo, search_alg=algo,
scheduler=scheduler, scheduler=scheduler,
num_samples=50, num_samples=-1,
metric="val loss", metric="val loss",
mode='min', mode='min',
), ),
run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test", run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
name="weight_val_loss_experiment" name="ref_count_threshold_experiment"
), ),
param_space=config param_space=config
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment