diff --git a/image_ref/main_sweep.py b/image_ref/main_sweep.py index 64fb7c7a550b39ebc693e0f512c14871c6436354..96d46f2b21e5540c0738d6aa08a7970ace868798 100644 --- a/image_ref/main_sweep.py +++ b/image_ref/main_sweep.py @@ -1,6 +1,6 @@ -import time - -import wandb as wdb +import random +import numpy as np +from sweep_train import run_duo if __name__ == '__main__': sweep_configuration = { @@ -10,7 +10,7 @@ if __name__ == '__main__': "parameters": { "epoches":{"value": 50}, "eval_inter":{"value": 1}, - "noise_threshold": {"distribution" : "log_uniform_values", "max": 10000., "min": 0.0001}, + "noise_threshold": {"distribution" : "log_uniform_values", "max": 10000., "min": 1}, "lr": {"distribution" : "log_uniform_values", "max": 0.01, "min": 0.0001}, "batch_size": {"value": 64}, "positive_prop": {"distribution" : "uniform","max": 95., "min": 5.}, @@ -21,15 +21,21 @@ if __name__ == '__main__': "dataset_val_dir": {"value": "data/processed_data_wiff/npy_image/test_data"}, "dataset_ref_dir": {"values": ["image_ref/img_ref","image_ref/img_ref_count_th_10","image_ref/img_ref_count_th_5"]}, }, - "controller":{ - "type": "local"}, + "max_iter": 10 } - sweep_id = wdb.sweep(sweep_configuration) + for i in range(sweep_configuration["max_iter"]): + run_config={} + for p,v in sweep_configuration["parameters"].items() : - # Start the local controller - sweep = wdb.controller(sweep_id) - while not sweep.done(): - sweep.print_status() - sweep.step() - time.sleep(5) + if "value" in v: + run_config[p]=v["value"] + elif "values" in v: + run_config[p] = random.choice(v["values"]) + elif "distribution" in v: + if v["distribution"]=="uniform": + run_config[p] = random.uniform(v["min"],v["max"]) + elif v["distribution"]=="log_uniform_values": + run_config[p] = np.exp(random.uniform(np.log(v["min"]), np.log(v["max"]))) + print('Launching run') + run_duo(run_config) diff --git a/image_ref/sweep_train.py b/image_ref/sweep_train.py index 056df14c20e5ee20512ab4165a15959539e3024a..fd544fe21261c453228e25b18936746347697b83 100644 --- a/image_ref/sweep_train.py +++ b/image_ref/sweep_train.py @@ -6,7 +6,7 @@ import torch.nn as nn from model import Classification_model_duo_contrastive import torch.optim as optim -def train_duo(model, data_train, optimizer, loss_function, epoch, wandb): +def train_duo(model, data_train, optimizer, loss_function, epoch): model.train() losses = 0. acc = 0. @@ -40,7 +40,7 @@ def train_duo(model, data_train, optimizer, loss_function, epoch, wandb): return losses, acc -def val_duo(model, data_test, loss_function, epoch, wandb): +def val_duo(model, data_test, loss_function, epoch): model.eval() losses = 0. acc = 0. @@ -94,15 +94,15 @@ def run_duo(args): print('Wandb initialised') # load data - data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, - base_dir_val=args.dataset_val_dir, + data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args['dataset_train_dir'], + base_dir_val=args['dataset_val_dir'], base_dir_test=None, - batch_size=args.batch_size, - ref_dir=args.dataset_ref_dir, - positive_prop=args.positive_prop, sampler=args.sampler) + batch_size=args['batch_size'], + ref_dir=args['dataset_ref_dir'], + positive_prop=args['positive_prop'], sampler=args['sampler']) # load model - model = Classification_model_duo_contrastive(model=args.model, n_class=2) + model = Classification_model_duo_contrastive(model=args['model'], n_class=2) model.float() # move parameters to GPU if torch.cuda.is_available(): @@ -118,15 +118,15 @@ def run_duo(args): val_loss = [] # init training loss_function = nn.CrossEntropyLoss() - if args.opti == 'adam': - optimizer = optim.Adam(model.parameters(), lr=args.lr) + if args['opti'] == 'adam': + optimizer = optim.Adam(model.parameters(), lr=args['lr']) # train model - for e in range(args.epoches): - loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb) + for e in range(args['epoches']): + loss, acc = train_duo(model, data_train, optimizer, loss_function, e) train_loss.append(loss) train_acc.append(acc) - if e % args.eval_inter == 0: - loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb) + if e % args['eval_inter'] == 0: + loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e) val_loss.append(loss) val_acc.append(acc) val_cont_acc.append(acc_contrastive) @@ -134,6 +134,4 @@ def run_duo(args): if __name__ == '__main__': - config = wdb.config - print(config) - run_duo(config) \ No newline at end of file + pass \ No newline at end of file