diff --git a/main.py b/main.py index f5a4d4e8172a9d23aaa128ecc2921ddb3323ef3b..3617217e15dded9ae61c5fb8adc7ba3f798e30d8 100644 --- a/main.py +++ b/main.py @@ -135,10 +135,18 @@ def make_prediction(model, data, f_name): def train_duo(model, data_train, optimizer, loss_function, epoch): model.train() + losses = 0. acc = 0. - for param in model.parameters(): - param.requires_grad = True + for n, p in model.im_encoder.named_parameters(): + if n in ['fc.weight', 'fc.bias']: + p.requires_grad = True + else: + p.requires_grad = False + + for n, p in model.predictor.named_parameters(): + p.requires_grad = True + for imaer,imana, label in data_train: label = label.long() @@ -154,6 +162,7 @@ def train_duo(model, data_train, optimizer, loss_function, epoch): optimizer.zero_grad() loss.backward() optimizer.step() + losses = losses/len(data_train.dataset) acc = acc/len(data_train.dataset) print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc)) @@ -190,6 +199,7 @@ def run_duo(args): model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes)) else : model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes)) + model.double() #load weight if args.pretrain_path is not None :