diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 756d5d1b0b91ad300403ffbfc5ecc2987fc2f7dd..02357705e7b03e1d133ff61be56a4007fa5cba64 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -7,6 +7,7 @@ import torch import wandb as wdb from matplotlib import pyplot as plt from sklearn.metrics import confusion_matrix +from sympy.holonomic.holonomic import domain_for_table from torch import optim, nn from model import BarlowTwins, BaseClassifier @@ -77,13 +78,15 @@ def save_model(model, path): torch.save(model.state_dict(), path) -def train_representation(model, data_train, optimizer, epoch, wandb): +def train_representation(model, data_train, optimizer, epoch, args): model.train() losses = 0. for param in model.parameters(): param.requires_grad = True - for img, img_ref in data_train: + for step,(img, img_ref) in enumerate(data_train): + if args.opti=='LARS': + adjust_learning_rate(args, optimizer, data_train, step) img = img.float() img_ref = img_ref.float() if torch.cuda.is_available(): @@ -97,7 +100,7 @@ def train_representation(model, data_train, optimizer, epoch, wandb): losses = losses / len(data_train.dataset) print('Train epoch {}, loss : {:.3f}'.format(epoch, losses)) - if wandb is not None: + if args.wandb is not None: wdb.log({"train loss": losses, 'train epoch': epoch}) return losses @@ -291,9 +294,7 @@ def run(): best_loss = np.inf for e in range(args.epoches): - if args.opti=='LARS': - adjust_learning_rate(args, optimizer, data_train, e) - _ = train_representation(model, data_train, optimizer, e, args.wandb) + _ = train_representation(model, data_train, optimizer, e, args) if e % args.eval_inter == 0: loss = test_representation(model, data_val, e, args.wandb) # if loss < best_loss: