From d8803e3cf5513a255ace5d67f4c9d632a237493a Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 21 May 2025 14:17:13 +0200
Subject: [PATCH] fix: LARS optimizer parameter groups and classifier optimizer

---
 barlow_twin_like/main.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index 9b7624a..756d5d1 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -259,6 +259,10 @@ def run():
     model = BarlowTwins(args)
     classifier = BaseClassifier(args,n_classes=n_classes)
     model.float()
+
+    param_weights = []
+    param_biases = []
+
     classifier.float()
     # load weight
     if args.pretrain_path is not None:
@@ -270,10 +274,18 @@ def run():
         model = model.cuda()
         classifier = classifier.cuda()
 
+    classifier_optimizer = optim.Adam(classifier.parameters(), lr=args.lr)
     if args.opti == 'adam':
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
     elif args.opti=='LARS':
-        optimizer = LARS(model.parameters(),lr=0)
+        for param in model.parameters():
+            if param.ndim == 1:
+                param_biases.append(param)
+            else:
+                param_weights.append(param)
+        parameters = [{'params': param_weights}, {'params': param_biases}]
+
+        optimizer = LARS(parameters,lr=0)
     else :
         optimizer = optim.SGD(model.parameters(),lr=args.lr)
 
@@ -293,7 +305,7 @@ def run():
         param.requires_grad = False
 
     for e in range(args.classification_epoches):
-        train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb)
+        train_classification(model, classifier, data_train_classifier, classifier_optimizer, e, args.wandb)
         test_classification(model, classifier, data_val_classifier, e, args.wandb)
 
     make_prediction_duo(model, classifier, data_val_classifier, args.base_out+'_confusion_matrix_val.png')
-- 
GitLab