diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index 756d5d1b0b91ad300403ffbfc5ecc2987fc2f7dd..02357705e7b03e1d133ff61be56a4007fa5cba64 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -7,6 +7,6 @@ import torch
 import wandb as wdb
 from matplotlib import pyplot as plt
 from sklearn.metrics import confusion_matrix
 from torch import optim, nn
 
 from model import BarlowTwins, BaseClassifier
@@ -77,13 +78,15 @@ def save_model(model, path):
     torch.save(model.state_dict(), path)
 
 
-def train_representation(model, data_train, optimizer, epoch, wandb):
+def train_representation(model, data_train, optimizer, epoch, args):
     model.train()
     losses = 0.
     for param in model.parameters():
         param.requires_grad = True
 
-    for img, img_ref in data_train:
+    for step, (img, img_ref) in enumerate(data_train):
+        if args.opti == 'LARS':
+            # LARS schedule expects the global step, not the within-epoch
+            # index; otherwise warmup/decay restarts every epoch.
+            adjust_learning_rate(args, optimizer, data_train,
+                                 epoch * len(data_train) + step)
         img = img.float()
         img_ref = img_ref.float()
         if torch.cuda.is_available():
@@ -97,7 +100,7 @@ def train_representation(model, data_train, optimizer, epoch, wandb):
     losses = losses / len(data_train.dataset)
     print('Train epoch {}, loss : {:.3f}'.format(epoch, losses))
 
-    if wandb is not None:
+    if args.wandb is not None:
         wdb.log({"train loss": losses, 'train epoch': epoch})
 
     return losses
@@ -291,9 +294,7 @@ def run():
 
     best_loss = np.inf
     for e in range(args.epoches):
-        if args.opti=='LARS':
-            adjust_learning_rate(args, optimizer, data_train, e)
-        _ = train_representation(model, data_train, optimizer, e, args.wandb)
+        _ = train_representation(model, data_train, optimizer, e, args)
         if e % args.eval_inter == 0:
             loss = test_representation(model, data_val, e, args.wandb)
     #         if loss < best_loss: