From 65c422aa3da21c6503dfb60c59460bd8219b08db Mon Sep 17 00:00:00 2001 From: lcalmettes <leo.calmettes@etu.ec-lyon.fr> Date: Mon, 12 May 2025 14:48:13 +0200 Subject: [PATCH] =?UTF-8?q?=09modifi=C3=A9=C2=A0:=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20dataset/dataset.py=20=09modifi=C3=A9=C2=A0:=20=20=20=20=20?= =?UTF-8?q?=20=20=20=20main.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataset/dataset.py | 6 +++--- main.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dataset/dataset.py b/dataset/dataset.py index 5119a73..e06b3f1 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -219,9 +219,9 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir, batch_size, args, shuffle=True): train_transform = transforms.Compose( - [Random_erasing(args.augment_args[0], args.augment_args[3]), - Random_int_noise(args.augment_args[1], args.augment_args[4]), - Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), + [#Random_erasing(args.augment_args[0], args.augment_args[3]), + #Random_int_noise(args.augment_args[1], args.augment_args[4]), + #Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), transforms.Resize((224, 224)), Threshold_noise(args.noise_threshold), Log_normalisation(), diff --git a/main.py b/main.py index 6aa0b89..2cd28d3 100644 --- a/main.py +++ b/main.py @@ -254,10 +254,10 @@ def run_duo(args): plt.plot(train_loss) plt.plot(train_loss) - plt.savefig(f'output/training_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') + plt.savefig(f'output/Atraining_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') #load and evaluate best model load_model(model, args.save_path) - make_prediction_duo(model,data_test, f'output/model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') + make_prediction_duo(model,data_test, f'output/Amodel_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') return best_loss,best_acc def make_prediction_duo(model, data, f_name): -- GitLab