From 28d81131802047b22176c3ca00a5c6b6fa053f4f Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 21 May 2025 13:46:09 +0200
Subject: [PATCH] add : lars optimizer and weight scheduling according to
 barlow twin

---
 barlow_twin_like/main.py | 68 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 66 insertions(+), 2 deletions(-)

diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index fef6992..08a4f15 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -1,3 +1,4 @@
+import math
 import os
 import seaborn as sn
 import numpy as np
@@ -12,6 +13,65 @@ from model import BarlowTwins, BaseClassifier
 from dataset_barlow import load_data_duo
 from config import load_args_barlow
 
+class LARS(optim.Optimizer):
+    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+                 weight_decay_filter=False, lars_adaptation_filter=False):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+                        eta=eta, weight_decay_filter=weight_decay_filter,
+                        lars_adaptation_filter=lars_adaptation_filter)
+        super().__init__(params, defaults)
+
+
+    def exclude_bias_and_norm(self, p):
+        return p.ndim == 1
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+                    dp = dp.add(p, alpha=g['weight_decay'])
+
+                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['eta'] * param_norm / update_norm), one), one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+
+                p.add_(mu, alpha=-g['lr'])
+
+
+
+
+def adjust_learning_rate(args, optimizer, loader, step):
+    max_steps = args.epochs * len(loader)
+    warmup_steps = 10 * len(loader)
+    base_lr = args.batch_size / 256
+    if step < warmup_steps:
+        lr = base_lr * step / warmup_steps
+    else:
+        step -= warmup_steps
+        max_steps -= warmup_steps
+        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+        end_lr = base_lr * 0.001
+        lr = base_lr * q + end_lr * (1 - q)
+    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
+    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
+
 def save_model(model, path):
     print('Model saved')
     torch.save(model.state_dict(), path)
@@ -212,11 +272,15 @@ def run():
 
     if args.opti == 'adam':
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
-    else:
-        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+    elif args.opti=='LARS':
+        optimizer = LARS(model.parameters(),lr=0)
+    else :
+        optimizer = optim.SGD(model.parameters(),lr=args.lr)
 
     best_loss = np.inf
     for e in range(args.epoches):
+        if args.opti=='LARS':
+            adjust_learning_rate(args, optimizer, data_train, e)
         _ = train_representation(model, data_train, optimizer, e, args.wandb)
         if e % args.eval_inter == 0:
             loss = test_representation(model, data_val, e, args.wandb)
-- 
GitLab