diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 08a4f15aae540d105b52c38d4ca2bff4c91093da..9b7624a92f1e40562a185559a0926c5cf76aaa2a 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -58,7 +58,7 @@ class LARS(optim.Optimizer): def adjust_learning_rate(args, optimizer, loader, step): - max_steps = args.epochs * len(loader) + max_steps = args.epoches * len(loader) warmup_steps = 10 * len(loader) base_lr = args.batch_size / 256 if step < warmup_steps: @@ -69,8 +69,8 @@ def adjust_learning_rate(args, optimizer, loader, step): q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) end_lr = base_lr * 0.001 lr = base_lr * q + end_lr * (1 - q) - optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights - optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases + optimizer.param_groups[0]['lr'] = lr * 0.2 + optimizer.param_groups[1]['lr'] = lr * 0.0048 def save_model(model, path): print('Model saved')