From d8803e3cf5513a255ace5d67f4c9d632a237493a Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Wed, 21 May 2025 14:17:13 +0200 Subject: [PATCH] fix : LARS opti --- barlow_twin_like/main.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 9b7624a..756d5d1 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -259,6 +259,10 @@ def run(): model = BarlowTwins(args) classifier = BaseClassifier(args,n_classes=n_classes) model.float() + + param_weights = [] + param_biases = [] + classifier.float() # load weight if args.pretrain_path is not None: @@ -270,10 +274,18 @@ def run(): model = model.cuda() classifier = classifier.cuda() + classifier_optimizer = optim.Adam(model.parameters(), lr=args.lr) if args.opti == 'adam': optimizer = optim.Adam(model.parameters(), lr=args.lr) elif args.opti=='LARS': - optimizer = LARS(model.parameters(),lr=0) + for param in model.parameters(): + if param.ndim == 1: + param_biases.append(param) + else: + param_weights.append(param) + parameters = [{'params': param_weights}, {'params': param_biases}] + + optimizer = LARS(parameters,lr=0) else : optimizer = optim.SGD(model.parameters(),lr=args.lr) @@ -293,7 +305,7 @@ def run(): param.requires_grad = False for e in range(args.classification_epoches): - train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb) + train_classification(model, classifier, data_train_classifier, classifier_optimizer, e, args.wandb) test_classification(model, classifier, data_val_classifier, e, args.wandb) make_prediction_duo(model, classifier, data_val_classifier, args.base_out+'_confusion_matrix_val.png') -- GitLab