Skip to content
Snippets Groups Projects
Commit 65c422aa authored by Léo Calmettes's avatar Léo Calmettes
Browse files

modifié : dataset/dataset.py

	modifié :         main.py
parent 3fdf0877
No related branches found
No related tags found
No related merge requests found
...@@ -219,9 +219,9 @@ class ImageFolderDuo(data.Dataset): ...@@ -219,9 +219,9 @@ class ImageFolderDuo(data.Dataset):
def load_data_duo(base_dir, batch_size, args, shuffle=True): def load_data_duo(base_dir, batch_size, args, shuffle=True):
train_transform = transforms.Compose( train_transform = transforms.Compose(
[Random_erasing(args.augment_args[0], args.augment_args[3]), [#Random_erasing(args.augment_args[0], args.augment_args[3]),
Random_int_noise(args.augment_args[1], args.augment_args[4]), #Random_int_noise(args.augment_args[1], args.augment_args[4]),
Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), #Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]),
transforms.Resize((224, 224)), transforms.Resize((224, 224)),
Threshold_noise(args.noise_threshold), Threshold_noise(args.noise_threshold),
Log_normalisation(), Log_normalisation(),
......
...@@ -254,10 +254,10 @@ def run_duo(args): ...@@ -254,10 +254,10 @@ def run_duo(args):
plt.plot(train_loss) plt.plot(train_loss)
plt.plot(train_loss) plt.plot(train_loss)
plt.savefig(f'output/training_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') plt.savefig(f'output/Atraining_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png')
#load and evaluate best model #load and evaluate best model
load_model(model, args.save_path) load_model(model, args.save_path)
make_prediction_duo(model,data_test, f'output/model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') make_prediction_duo(model,data_test, f'output/Amodel_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png')
return best_loss,best_acc return best_loss,best_acc
def make_prediction_duo(model, data, f_name): def make_prediction_duo(model, data, f_name):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment