Skip to content
Snippets Groups Projects
Commit e68172ea authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : config

parent eb0067b0
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ def train_model(config,args):
batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir,
noise_threshold=config['noise'],
positive_prop=config['positive_prop'], sampler=config['sampler'])
positive_prop=config['p_prop'], sampler=config['sampler'])
# load model
model = Classification_model_duo_contrastive(model=args.model, n_class=2)
......@@ -142,7 +142,7 @@ def train_model(config,args):
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
print(checkpoint.path)
train.report(
{"train loss": losses_train, "train contrastive acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val contrastive acc": acc_contrastive_val,},
{"train loss": losses_train, "train cont acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val cont acc": acc_contrastive_val,},
checkpoint=checkpoint,)
print("Finished Training")
......@@ -155,7 +155,7 @@ def test_model(best_result, args):
batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir,
noise_threshold=best_result.config['noise'],
positive_prop=best_result.config['positive_prop'], sampler=best_result.config['sampler'])
positive_prop=best_result.config['p_prop'], sampler=best_result.config['sampler'])
# load model
model = Classification_model_duo_contrastive(model=args.model, n_class=2)
......@@ -213,13 +213,13 @@ def main(args, gpus_per_trial=1):
config = {
"lr": tune.loguniform(1e-4, 1e-2),
"noise": tune.loguniform(1, 1000),
"positive_prop": tune.uniform(0, 100),
"p_prop": tune.uniform(5, 95),
"optimizer": tune.choice(['Adam', 'SGD']),
"sampler": tune.choice(['random', 'balanced']),
}
scheduler = ASHAScheduler(
max_t=100,
grace_period=20,
grace_period=3,
reduction_factor=3,
brackets=1,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment