Skip to content
Snippets Groups Projects
Commit d8803e3c authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : LARS opti

parent df1671e9
No related branches found
No related tags found
No related merge requests found
......@@ -259,6 +259,10 @@ def run():
model = BarlowTwins(args)
classifier = BaseClassifier(args,n_classes=n_classes)
model.float()
param_weights = []
param_biases = []
classifier.float()
# load weight
if args.pretrain_path is not None:
......@@ -270,10 +274,18 @@ def run():
model = model.cuda()
classifier = classifier.cuda()
classifier_optimizer = optim.Adam(model.parameters(), lr=args.lr)
if args.opti == 'adam':
optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.opti=='LARS':
optimizer = LARS(model.parameters(),lr=0)
for param in model.parameters():
if param.ndim == 1:
param_biases.append(param)
else:
param_weights.append(param)
parameters = [{'params': param_weights}, {'params': param_biases}]
optimizer = LARS(parameters,lr=0)
else :
optimizer = optim.SGD(model.parameters(),lr=args.lr)
......@@ -293,7 +305,7 @@ def run():
param.requires_grad = False
for e in range(args.classification_epoches):
train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb)
train_classification(model, classifier, data_train_classifier, classifier_optimizer, e, args.wandb)
test_classification(model, classifier, data_val_classifier, e, args.wandb)
make_prediction_duo(model, classifier, data_val_classifier, args.base_out+'_confusion_matrix_val.png')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment