Skip to content
Snippets Groups Projects
Commit df1671e9 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : lars optimizer and weight scheduling according to barlow twin

parent 28d81131
No related branches found
No related tags found
No related merge requests found
...@@ -58,7 +58,7 @@ class LARS(optim.Optimizer): ...@@ -58,7 +58,7 @@ class LARS(optim.Optimizer):
def adjust_learning_rate(args, optimizer, loader, step): def adjust_learning_rate(args, optimizer, loader, step):
max_steps = args.epochs * len(loader) max_steps = args.epoches * len(loader)
warmup_steps = 10 * len(loader) warmup_steps = 10 * len(loader)
base_lr = args.batch_size / 256 base_lr = args.batch_size / 256
if step < warmup_steps: if step < warmup_steps:
...@@ -69,8 +69,8 @@ def adjust_learning_rate(args, optimizer, loader, step): ...@@ -69,8 +69,8 @@ def adjust_learning_rate(args, optimizer, loader, step):
q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
end_lr = base_lr * 0.001 end_lr = base_lr * 0.001
lr = base_lr * q + end_lr * (1 - q) lr = base_lr * q + end_lr * (1 - q)
optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights optimizer.param_groups[0]['lr'] = lr * 0.2
optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases optimizer.param_groups[1]['lr'] = lr * 0.0048
def save_model(model, path): def save_model(model, path):
print('Model saved') print('Model saved')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment