diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index 9b7624a92f1e40562a185559a0926c5cf76aaa2a..756d5d1b0b91ad300403ffbfc5ecc2987fc2f7dd 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -259,6 +259,10 @@ def run():
     model = BarlowTwins(args)
     classifier = BaseClassifier(args,n_classes=n_classes)
     model.float()
+
+    param_weights = []
+    param_biases = []
+
     classifier.float()
     # load weight
     if args.pretrain_path is not None:
@@ -270,10 +274,18 @@ def run():
         model = model.cuda()
         classifier = classifier.cuda()
 
+    classifier_optimizer = optim.Adam(classifier.parameters(), lr=args.lr)  # trains only the classifier head
     if args.opti == 'adam':
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
     elif args.opti=='LARS':
-        optimizer = LARS(model.parameters(),lr=0)
+        for param in model.parameters():  # split 1-D params (biases, BN) from weight tensors
+            if param.ndim == 1:
+                param_biases.append(param)
+            else:
+                param_weights.append(param)
+        parameters = [{'params': param_weights}, {'params': param_biases}]
+
+        optimizer = LARS(parameters,lr=0)
     else :
         optimizer = optim.SGD(model.parameters(),lr=args.lr)
 
@@ -293,7 +305,7 @@ def run():
        param.requires_grad = False
 
    for e in range(args.classification_epoches):
-        train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb)
+        train_classification(model, classifier, data_train_classifier, classifier_optimizer, e, args.wandb)
        test_classification(model, classifier, data_val_classifier, e, args.wandb)
 
    make_prediction_duo(model, classifier, data_val_classifier, args.base_out+'_confusion_matrix_val.png')
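
The LARS branch above splits the model's parameters into weight tensors (ndim > 1) and 1-D parameters (biases, BatchNorm scales and shifts). In the reference Barlow Twins setup this split lets the 1-D group be excluded from weight decay and from the layer-wise LARS adaptation; lr=0 at construction suggests the learning rate is set elsewhere during training. Below is a minimal, self-contained sketch of the grouping idea; the toy model and the use of torch.optim.SGD are illustrative only (the diff passes the same group structure to this repo's LARS class, whose exact options are not shown here).

import torch
import torch.nn as nn

# Toy module mix, just to have both 1-D and >1-D parameters.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 10))

param_weights, param_biases = [], []
for param in model.parameters():
    # 1-D tensors are biases and BatchNorm affine terms; the rest are weight kernels/matrices.
    (param_biases if param.ndim == 1 else param_weights).append(param)

parameters = [
    {'params': param_weights},                      # weight decay / LARS adaptation apply here
    {'params': param_biases, 'weight_decay': 0.0},  # commonly exempt from weight decay
]

# Illustrative only: any torch optimizer accepts per-group options in this form.
optimizer = torch.optim.SGD(parameters, lr=0.1, weight_decay=1e-4)
print(len(param_weights), "weight tensors,", len(param_biases), "bias/1-D tensors")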