From 28d81131802047b22176c3ca00a5c6b6fa053f4f Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 21 May 2025 13:46:09 +0200
Subject: [PATCH] add : lars optimizer and weight scheduling according to
 barlow twin

---
 barlow_twin_like/main.py | 68 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 66 insertions(+), 2 deletions(-)

diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index fef6992..08a4f15 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -1,3 +1,4 @@
+import math
 import os
 import seaborn as sn
 import numpy as np
@@ -12,6 +13,65 @@ from model import BarlowTwins, BaseClassifier
 from dataset_barlow import load_data_duo
 from config import load_args_barlow
 
+class LARS(optim.Optimizer):
+    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+                 weight_decay_filter=False, lars_adaptation_filter=False):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+                        eta=eta, weight_decay_filter=weight_decay_filter,
+                        lars_adaptation_filter=lars_adaptation_filter)
+        super().__init__(params, defaults)
+
+
+    def exclude_bias_and_norm(self, p):
+        return p.ndim == 1
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+                    dp = dp.add(p, alpha=g['weight_decay'])
+
+                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['eta'] * param_norm / update_norm), one), one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+
+                p.add_(mu, alpha=-g['lr'])
+
+
+
+
+def adjust_learning_rate(args, optimizer, loader, step):
+    max_steps = args.epochs * len(loader)
+    warmup_steps = 10 * len(loader)
+    base_lr = args.batch_size / 256
+    if step < warmup_steps:
+        lr = base_lr * step / warmup_steps
+    else:
+        step -= warmup_steps
+        max_steps -= warmup_steps
+        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+        end_lr = base_lr * 0.001
+        lr = base_lr * q + end_lr * (1 - q)
+    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
+    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
+
 def save_model(model, path):
     print('Model saved')
     torch.save(model.state_dict(), path)
@@ -212,11 +272,15 @@ def run():
 
     if args.opti == 'adam':
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
-    else:
-        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+    elif args.opti=='LARS':
+        optimizer = LARS(model.parameters(),lr=0)
+    else :
+        optimizer = optim.SGD(model.parameters(),lr=args.lr)
 
     best_loss = np.inf
     for e in range(args.epoches):
+        if args.opti=='LARS':
+            adjust_learning_rate(args, optimizer, data_train, e)
         _ = train_representation(model, data_train, optimizer, e, args.wandb)
        if e % args.eval_inter == 0:
             loss = test_representation(model, data_val, e, args.wandb)
--
GitLab
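
Note on usage: the per-group learning rates written at the end of adjust_learning_rate (param_groups[0] for weights, param_groups[1] for biases) assume the LARS optimizer is constructed from two parameter groups, as in the Barlow Twins reference training script. Below is a minimal sketch of that construction, assuming the model object and the LARS class from main.py; the weight_decay value and the filter flags are illustrative choices, not taken from this patch.

# Sketch (assumption, not part of the patch): build LARS from two parameter
# groups so that adjust_learning_rate's param_groups[0]/param_groups[1]
# indexing lines up with weights and biases.
param_weights = [p for p in model.parameters() if p.ndim > 1]   # conv / linear weights
param_biases = [p for p in model.parameters() if p.ndim == 1]   # biases and norm parameters
optimizer = LARS([{'params': param_weights}, {'params': param_biases}],
                 lr=0,                         # overwritten by adjust_learning_rate
                 weight_decay=1e-6,            # illustrative value
                 weight_decay_filter=True,     # skip weight decay on 1-D params
                 lars_adaptation_filter=True)  # skip LARS scaling on 1-D params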