diff --git a/image_ref/main_sweep.py b/image_ref/main_sweep.py
new file mode 100644
index 0000000000000000000000000000000000000000..b136e4a83917aee146ccfea85fdb32d3130fe8fa
--- /dev/null
+++ b/image_ref/main_sweep.py
@@ -0,0 +1,32 @@
+
+import wandb as wdb
+
+
+if __name__ == '__main__':
+    # Random search over the contrastive training hyperparameters, minimizing
+    # the "validation loss" metric logged by sweep_train.py.
+    sweep_configuration = {
+        "program": "sweep_train.py",
+        "method": "random",
+        "metric": {"goal": "minimize", "name": "validation loss"},
+        "parameters": {
+            "epoches": {"value": 50},
+            "eval_inter": {"value": 1},
+            "noise_threshold": {"distribution": "log_uniform_values", "max": 10000., "min": 0.0001},
+            "lr": {"distribution": "log_uniform_values", "max": 0.01, "min": 0.0001},
+            "batch_size": {"value": 64},
+            "positive_prop": {"distribution": "uniform", "max": 95., "min": 5.},
+            "opti": {"value": "adam"},
+            "model": {"value": "resnet18"},
+            "sampler": {"values": ["random", "balanced"]},
+            "dataset_train_dir": {"value": "data/processed_data_wiff/npy_image/train_data"},
+            "dataset_val_dir": {"value": "data/processed_data_wiff/npy_image/test_data"},
+            "dataset_ref_dir": {"values": ["image_ref/img_ref", "image_ref/img_ref_count_th_10", "image_ref/img_ref_count_th_5"]},
+        },
+    }
+    sweep_id = wdb.sweep(sweep=sweep_configuration, project="param_sweep_contrastive")
+
+    # Drive the sweep with a local controller rather than the wandb server.
+    sweep = wdb.controller(sweep_id)
+    sweep.configure_controller(type="local")
+    sweep.run()
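main_sweep.py registers the sweep and drives it with a local controller, but the controller only assigns parameters; one or more agents still have to execute sweep_train.py for each trial. A minimal launcher sketch, assuming the sweep ID printed by wdb.sweep() above (the ID string below is a placeholder):

    import wandb as wdb

    # Placeholder: wdb.sweep() prints the real "<entity>/<project>/<id>" path.
    sweep_id = "<entity>/param_sweep_contrastive/<sweep_id>"

    # Fetch parameter sets from the controller and run the sweep's "program"
    # (sweep_train.py) once per trial; count caps the trials for this agent.
    wdb.agent(sweep_id, count=10)

The same agent can be started from a shell with: wandb agent <entity>/param_sweep_contrastive/<sweep_id>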
diff --git a/image_ref/sweep_train.py b/image_ref/sweep_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..056df14c20e5ee20512ab4165a15959539e3024a
--- /dev/null
+++ b/image_ref/sweep_train.py
@@ -0,0 +1,142 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import wandb as wdb
+
+from dataset_ref import load_data_duo
+from model import Classification_model_duo_contrastive
+
+
+def train_duo(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for imaer, imana, img_ref, label in data_train:
+        imaer = imaer.float()
+        imana = imana.float()
+        img_ref = img_ref.float()
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+    wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})
+    return losses, acc
+
+
+def val_duo(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    acc_contrastive = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, img_ref, label in data_test:
+        imaer = imaer.float()
+        imana = imana.float()
+        img_ref = img_ref.float()
+        # Each validation batch holds one group of candidate references:
+        # move the group dimension first so the whole group is scored at once.
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1).squeeze().long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        # The positive candidate is the one labelled with class 0.
+        label_class = torch.argmin(label).data.cpu().numpy()
+        pred_logits = model(imaer, imana, img_ref)
+        # Classification prediction: the candidate with the highest class-0 logit.
+        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+        acc_contrastive += (torch.argmax(pred_logits, dim=1).data.cpu().numpy()
+                            == label.data.cpu().numpy()).sum().item()
+        acc += (pred_class == label_class)
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    # Normalization assumes every group has the same size (label.shape[0]).
+    losses = losses / (label.shape[0] * len(data_test.dataset))
+    acc = acc / len(data_test.dataset)
+    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
+    print('Val epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
+                                                                                     acc_contrastive))
+    wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
+             "validation contrastive accuracy": acc_contrastive})
+    return losses, acc, acc_contrastive
+
+
+def run_duo():
+    # wandb init
+    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
+    os.environ["WANDB_MODE"] = "offline"
+    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
+    wdb.init(project="param_sweep_contrastive", dir='./wandb_run')
+    # The sweep agent injects the trial's hyperparameters into wdb.config;
+    # they are only available once wdb.init() has run.
+    args = wdb.config
+    print('Wandb initialised')
+
+    # load data
+    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                                                base_dir_val=args.dataset_val_dir,
+                                                                base_dir_test=None,
+                                                                batch_size=args.batch_size,
+                                                                ref_dir=args.dataset_ref_dir,
+                                                                positive_prop=args.positive_prop,
+                                                                sampler=args.sampler)
+
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+    model.float()
+    # move parameters to GPU
+    if torch.cuda.is_available():
+        print('Model loaded on GPU')
+        model = model.cuda()
+
+    # init accumulators
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_cont_acc = []
+    val_loss = []
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    if args.opti == 'adam':
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise ValueError('Unsupported optimizer: {}'.format(args.opti))
+    # train model
+    for e in range(args.epoches):
+        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e % args.eval_inter == 0:
+            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            val_cont_acc.append(acc_contrastive)
+    wdb.finish()
+
+
+if __name__ == '__main__':
+    run_duo()
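For context, the transpose/squeeze block in val_duo assumes the validation loader yields one group of candidate reference images per batch, with the positive pair labelled class 0 and the rest class 1. A toy illustration with dummy tensors (the group size, channel count, and resolution are assumptions for the example, not taken from dataset_ref):

    import torch

    # One validation batch: a single group of 5 candidates, 1x64x64 images.
    imaer = torch.randn(1, 5, 1, 64, 64)
    label = torch.tensor([[1, 1, 0, 1, 1]])  # class 0 marks the positive candidate

    imaer = imaer.transpose(0, 1)                   # (5, 1, 1, 64, 64): candidates become the batch dim
    label = label.transpose(0, 1).squeeze().long()  # (5,): one class label per candidate
    label_class = torch.argmin(label)               # index of the positive candidate -> tensor(2)
    print(imaer.shape, label.shape, label_class.item())

With this layout the whole group goes through the model in one forward pass, and the argmax over pred_logits[:, 0] picks the candidate the model most confidently assigns to the positive class.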