Commit 28d81131 authored by Schneider Leo

add: LARS optimizer and weight scheduling according to Barlow Twins

parent 2886a781
import math
import os
import seaborn as sn
import numpy as np
@@ -12,6 +13,65 @@ from model import BarlowTwins, BaseClassifier
from dataset_barlow import load_data_duo
from config import load_args_barlow
class LARS(optim.Optimizer):
    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        # 1-D parameters are biases and normalisation weights
        return p.ndim == 1

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue

                # weight decay (optionally skipped for biases / norm parameters)
                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                # layer-wise trust ratio: scale the update by eta * ||w|| / ||grad||
                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                # momentum buffer and SGD-style parameter update
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
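
For reference, a minimal usage sketch of the optimizer above (not part of the commit; the toy model, batch size and hyperparameter values are arbitrary assumptions chosen only to exercise the update):

import torch
from torch import nn

toy = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4))
opt = LARS(toy.parameters(), lr=0.2, weight_decay=1e-6,
           weight_decay_filter=True, lars_adaptation_filter=True)

x, y = torch.randn(32, 8), torch.randint(0, 4, (32,))
loss = nn.functional.cross_entropy(toy(x), y)
opt.zero_grad()
loss.backward()
opt.step()  # applies the layer-wise trust ratio, then the momentum update

With the filters enabled, biases and BatchNorm parameters (ndim == 1) are excluded from both weight decay and the LARS adaptation, as in the Barlow Twins reference implementation.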
def adjust_learning_rate(args, optimizer, loader, step):
    # Barlow Twins schedule: linear warmup over 10 epochs worth of steps,
    # then cosine decay down to base_lr * 0.001, with separate scalings for
    # the weight and bias parameter groups.
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    base_lr = args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
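
A quick sketch (also not part of the commit) of how this schedule behaves, using a dummy two-group optimizer and hypothetical settings (batch_size=512 so base_lr = 2.0, 100 epochs, a loader of length 10, and the weight/bias scalings 0.2 and 0.0048 used in the Barlow Twins paper):

from types import SimpleNamespace
from torch import nn, optim

dummy = nn.Linear(4, 4)
opt = optim.SGD([{'params': [dummy.weight]}, {'params': [dummy.bias]}], lr=0)
args = SimpleNamespace(epochs=100, batch_size=512,
                       learning_rate_weights=0.2, learning_rate_biases=0.0048)
loader = list(range(10))  # anything with len() works here

for step in (0, 50, 100, 500, 999):  # warmup ends at step 100, decay runs to 1000
    adjust_learning_rate(args, opt, loader, step)
    print(step, opt.param_groups[0]['lr'], opt.param_groups[1]['lr'])

The learning rate ramps linearly over the first 10 * len(loader) steps and then follows a cosine curve down toward 0.1% of the peak value.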
def save_model(model, path):
    print('Model saved')
    torch.save(model.state_dict(), path)
@@ -212,11 +272,15 @@ def run():
    if args.opti == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.opti == 'LARS':
        # adjust_learning_rate expects two parameter groups (weights, then
        # biases/norm parameters), so the parameters are split accordingly;
        # the actual learning rate is set by the scheduler, hence lr=0 here
        param_weights = [p for p in model.parameters() if p.ndim != 1]
        param_biases = [p for p in model.parameters() if p.ndim == 1]
        optimizer = LARS([{'params': param_weights}, {'params': param_biases}], lr=0)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr)
    best_loss = np.inf
    for e in range(args.epoches):
        if args.opti == 'LARS':
            adjust_learning_rate(args, optimizer, data_train, e)
        _ = train_representation(model, data_train, optimizer, e, args.wandb)
        if e % args.eval_inter == 0:
            loss = test_representation(model, data_val, e, args.wandb)