Skip to content
Snippets Groups Projects
Commit 5c611b35 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : frozen param

parent 600f9734
No related branches found
No related tags found
No related merge requests found
......@@ -135,10 +135,18 @@ def make_prediction(model, data, f_name):
def train_duo(model, data_train, optimizer, loss_function, epoch):
model.train()
losses = 0.
acc = 0.
for param in model.parameters():
param.requires_grad = True
for n, p in model.im_encoder.named_parameters():
if n in ['fc.weight', 'fc.bias']:
p.requires_grad = True
else:
p.requires_grad = False
for n, p in model.predictor.named_parameters():
p.requires_grad = True
for imaer,imana, label in data_train:
label = label.long()
......@@ -154,6 +162,7 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses = losses/len(data_train.dataset)
acc = acc/len(data_train.dataset)
print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
......@@ -190,6 +199,7 @@ def run_duo(args):
model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes))
else :
model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes))
model.double()
#load weight
if args.pretrain_path is not None :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment