From 5c611b3582f13fcb6e5f9775eadc0c327ce1ca0f Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 16 Jun 2025 11:41:51 +0200
Subject: [PATCH] fix: freeze image-encoder params except final fc layer; keep predictor trainable

---
 main.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index f5a4d4e..3617217 100644
--- a/main.py
+++ b/main.py
@@ -135,10 +135,18 @@ def make_prediction(model, data, f_name):
 
 def train_duo(model, data_train, optimizer, loss_function, epoch):
     model.train()
+
     losses = 0.
     acc = 0.
-    for param in model.parameters():
-        param.requires_grad = True
+    for n, p in model.im_encoder.named_parameters():
+        if n in ['fc.weight', 'fc.bias']:
+            p.requires_grad = True
+        else:
+            p.requires_grad = False
+
+    for n, p in model.predictor.named_parameters():
+        p.requires_grad = True
+
 
     for imaer,imana, label in data_train:
         label = label.long()
@@ -154,6 +162,7 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
+
     losses = losses/len(data_train.dataset)
     acc = acc/len(data_train.dataset)
     print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
@@ -190,6 +199,7 @@ def run_duo(args):
         model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes))
     else :
         model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes))
+
     model.double()
     #load weight
     if args.pretrain_path is not None :
-- 
GitLab