From 5c611b3582f13fcb6e5f9775eadc0c327ce1ca0f Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 16 Jun 2025 11:41:51 +0200 Subject: [PATCH] fix : frozen param --- main.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index f5a4d4e..3617217 100644 --- a/main.py +++ b/main.py @@ -135,10 +135,18 @@ def make_prediction(model, data, f_name): def train_duo(model, data_train, optimizer, loss_function, epoch): model.train() + losses = 0. acc = 0. - for param in model.parameters(): - param.requires_grad = True + for n, p in model.im_encoder.named_parameters(): + if n in ['fc.weight', 'fc.bias']: + p.requires_grad = True + else: + p.requires_grad = False + + for n, p in model.predictor.named_parameters(): + p.requires_grad = True + for imaer,imana, label in data_train: label = label.long() @@ -154,6 +162,7 @@ def train_duo(model, data_train, optimizer, loss_function, epoch): optimizer.zero_grad() loss.backward() optimizer.step() + losses = losses/len(data_train.dataset) acc = acc/len(data_train.dataset) print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc)) @@ -190,6 +199,7 @@ def run_duo(args): model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes)) else : model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes)) + model.double() #load weight if args.pretrain_path is not None : -- GitLab