"""Train and validate a two-view contrastive classifier with W&B logging."""
import os

import torch
import torch.nn as nn
import torch.optim as optim
import wandb as wdb

from dataset_ref import load_data_duo
from model import Classification_model_duo_contrastive


def _maybe_cuda(*tensors):
    """Return ``tensors`` moved to the GPU when CUDA is available, else unchanged."""
    if torch.cuda.is_available():
        return tuple(t.cuda() for t in tensors)
    return tensors


def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
    """Run one training epoch and log the resulting metrics to W&B.

    Args:
        model: Module called as ``model(imaer, imana, img_ref)``; its output
            is used as per-class logits (``argmax`` over ``dim=1``).
        data_train: Iterable of ``(imaer, imana, img_ref, label)`` batches
            that exposes a sized ``dataset`` attribute.
        optimizer: Optimizer stepped once per batch.
        loss_function: Callable ``loss_function(logits, label)`` returning a
            scalar tensor.
        epoch: Epoch index, used only for printing and logging.
        wandb: Unused; kept for call compatibility. Logging goes through the
            module-level ``wdb`` import.

    Returns:
        Tuple ``(loss, accuracy)``, both divided by ``len(data_train.dataset)``.
    """
    model.train()
    losses = 0.
    acc = 0.
    # val_duo freezes every parameter, so re-enable gradients before training.
    for param in model.parameters():
        param.requires_grad = True

    for imaer, imana, img_ref, label in data_train:
        imaer, imana, img_ref, label = _maybe_cuda(
            imaer.float(), imana.float(), img_ref.float(), label.long())

        # Call the module itself (not .forward) so registered hooks run.
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits, dim=1)
        acc += (pred_class == label).sum().item()

        loss = loss_function(pred_logits, label)
        losses += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    n_samples = len(data_train.dataset)
    # NOTE(review): the loss is accumulated once per batch but divided by the
    # number of samples. With a batch-averaging criterion (the default for the
    # nn.CrossEntropyLoss built in run_duo) the reported value is scaled down
    # by roughly the batch size. Kept as-is so logged curves stay comparable.
    losses = losses / n_samples
    acc = acc / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    wdb.log({"train loss": losses,
             'train epoch': epoch,
             "train contrastive accuracy": acc})
    return losses, acc


def val_duo(model, data_test, loss_function, epoch, wandb):
    """Evaluate ``model`` on ``data_test`` and log the metrics to W&B.

    Each item from ``data_test`` has its first two dimensions swapped before
    being fed to the model, and its label tensor is squeezed to one dimension.
    NOTE(review): this presumably means the loader yields one item at a time,
    each holding a set of candidates -- confirm against ``load_data_duo``.

    Two accuracies are reported:

    * classification accuracy: the candidate with the highest logit in
      column 0 must be the candidate with the smallest label (``argmin``);
    * contrastive accuracy: per-candidate ``argmax`` of the logits compared
      with the per-candidate label.

    Args:
        model: Module called as ``model(imaer, imana, img_ref)``.
        data_test: Iterable of ``(imaer, imana, img_ref, label)`` items that
            exposes a sized ``dataset`` attribute.
        loss_function: Callable ``loss_function(logits, label)`` returning a
            scalar tensor.
        epoch: Epoch index, used only for printing and logging.
        wandb: Unused; kept for call compatibility. Logging goes through the
            module-level ``wdb`` import.

    Returns:
        Tuple ``(loss, accuracy, contrastive_accuracy)``.

    Raises:
        ValueError: If ``data_test`` yields no items.
    """
    model.eval()
    losses = 0.
    acc = 0.
    acc_contrastive = 0.
    n_candidates = None
    # Freeze parameters so no autograd graph is built during evaluation;
    # train_duo switches them back on at the start of every epoch.
    for param in model.parameters():
        param.requires_grad = False

    for imaer, imana, img_ref, label in data_test:
        imaer = imaer.float().transpose(0, 1)
        imana = imana.float().transpose(0, 1)
        img_ref = img_ref.float().transpose(0, 1)
        label = label.transpose(0, 1).squeeze().long()
        imaer, imana, img_ref, label = _maybe_cuda(imaer, imana, img_ref, label)
        # The last seen candidate count is used for normalisation below.
        n_candidates = label.shape[0]

        # Call the module itself (not .forward) so registered hooks run.
        pred_logits = model(imaer, imana, img_ref)

        label_class = torch.argmin(label).item()
        pred_class = torch.argmax(pred_logits[:, 0]).item()
        acc += float(pred_class == label_class)
        acc_contrastive += (torch.argmax(pred_logits, dim=1) == label).sum().item()

        loss = loss_function(pred_logits, label)
        losses += loss.item()

    if n_candidates is None:
        raise ValueError("val_duo received an empty data loader; nothing to evaluate")

    n_items = len(data_test.dataset)
    # NOTE(review): the normalisation assumes every item has the same number
    # of candidates as the last one, and it also divides the per-item loss by
    # that count -- confirm intended. Kept so logged curves stay comparable.
    losses = losses / (n_candidates * n_items)
    acc = acc / n_items
    acc_contrastive = acc_contrastive / (n_candidates * n_items)
    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
                                                                                    acc_contrastive))
    wdb.log({"validation loss": losses,
             'validation epoch': epoch,
             "validation classification accuracy": acc,
             "validation contrastive accuracy": acc_contrastive})
    return losses, acc, acc_contrastive


def run_duo(args):
    """Initialise W&B, build data and model, then train with periodic validation.

    Args:
        args: Configuration object read by attribute. The attributes used are
            ``dataset_train_dir``, ``dataset_val_dir``, ``dataset_ref_dir``,
            ``batch_size``, ``positive_prop``, ``sampler``, ``model``,
            ``opti``, ``lr``, ``epoches``, ``eval_inter`` and ``wandb``.

    Raises:
        ValueError: If ``args.opti`` names an unsupported optimizer.
    """
    # wandb init
    # SECURITY: an API key is committed in source. It should be rotated and
    # supplied through the environment; setdefault lets an external key win.
    os.environ.setdefault("WANDB_API_KEY", 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd')
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

    wdb.init(project="param_sweep_contrastive", dir='./wandb_run')
    print('Wandb initialised')

    # load data
    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                                base_dir_val=args.dataset_val_dir,
                                                                base_dir_test=None,
                                                                batch_size=args.batch_size,
                                                                ref_dir=args.dataset_ref_dir,
                                                                positive_prop=args.positive_prop,
                                                                sampler=args.sampler)

    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.float()

    # move parameters to GPU
    if torch.cuda.is_available():
        print('Model loaded on GPU')
        model = model.cuda()

    # init accumulators
    train_acc = []
    train_loss = []
    val_acc = []
    val_cont_acc = []
    val_loss = []

    # init training
    loss_function = nn.CrossEntropyLoss()
    if args.opti == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        # Fail early instead of hitting an unbound `optimizer` in the loop.
        raise ValueError("Unsupported optimizer {!r}; only 'adam' is implemented".format(args.opti))

    # train model
    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb)
            val_loss.append(loss)
            val_acc.append(acc)
            val_cont_acc.append(acc_contrastive)
    wdb.finish()


if __name__ == '__main__':
    # NOTE(review): wdb.config is read here before run_duo calls wdb.init();
    # confirm it already carries the attributes run_duo reads at this point.
    config = wdb.config
    print(config)
    run_duo(config)