diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 353a01593a46f0aec305f4d3cb72bef86487569f..fef6992326dda8118b82536fe2ab47418a1d6887 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -220,11 +220,11 @@ def run(): _ = train_representation(model, data_train, optimizer, e, args.wandb) if e % args.eval_inter == 0: loss = test_representation(model, data_val, e, args.wandb) - if loss < best_loss: - save_model(model, args.save_path) - best_loss = loss - - model.load_state_dict((torch.load(args.save_path, weights_only=True))) #load best model + # if loss < best_loss: + # save_model(model, args.save_path) + # best_loss = loss + # + # model.load_state_dict((torch.load(args.save_path, weights_only=True))) #load best model for param in model.parameters(): # freezing representations before classifier training param.requires_grad = False