Skip to content
Snippets Groups Projects
Commit e68172ea authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : config

parent eb0067b0
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,7 @@ def train_model(config,args): ...@@ -24,7 +24,7 @@ def train_model(config,args):
batch_size=args.batch_size, batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir, ref_dir=args.dataset_ref_dir,
noise_threshold=config['noise'], noise_threshold=config['noise'],
positive_prop=config['positive_prop'], sampler=config['sampler']) positive_prop=config['p_prop'], sampler=config['sampler'])
# load model # load model
model = Classification_model_duo_contrastive(model=args.model, n_class=2) model = Classification_model_duo_contrastive(model=args.model, n_class=2)
...@@ -142,7 +142,7 @@ def train_model(config,args): ...@@ -142,7 +142,7 @@ def train_model(config,args):
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
print(checkpoint.path) print(checkpoint.path)
train.report( train.report(
{"train loss": losses_train, "train contrastive acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val contrastive acc": acc_contrastive_val,}, {"train loss": losses_train, "train cont acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val cont acc": acc_contrastive_val,},
checkpoint=checkpoint,) checkpoint=checkpoint,)
print("Finished Training") print("Finished Training")
...@@ -155,7 +155,7 @@ def test_model(best_result, args): ...@@ -155,7 +155,7 @@ def test_model(best_result, args):
batch_size=args.batch_size, batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir, ref_dir=args.dataset_ref_dir,
noise_threshold=best_result.config['noise'], noise_threshold=best_result.config['noise'],
positive_prop=best_result.config['positive_prop'], sampler=best_result.config['sampler']) positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler'])
# load model # load model
model = Classification_model_duo_contrastive(model=args.model, n_class=2) model = Classification_model_duo_contrastive(model=args.model, n_class=2)
...@@ -213,13 +213,13 @@ def main(args, gpus_per_trial=1): ...@@ -213,13 +213,13 @@ def main(args, gpus_per_trial=1):
config = { config = {
"lr": tune.loguniform(1e-4, 1e-2), "lr": tune.loguniform(1e-4, 1e-2),
"noise": tune.loguniform(1, 1000), "noise": tune.loguniform(1, 1000),
"positive_prop": tune.uniform(0, 100), "p_prop": tune.uniform(5, 95),
"optimizer": tune.choice(['Adam', 'SGD']), "optimizer": tune.choice(['Adam', 'SGD']),
"sampler": tune.choice(['random', 'balanced']), "sampler": tune.choice(['random', 'balanced']),
} }
scheduler = ASHAScheduler( scheduler = ASHAScheduler(
max_t=100, max_t=100,
grace_period=20, grace_period=3,
reduction_factor=3, reduction_factor=3,
brackets=1, brackets=1,
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment