diff --git a/AugmentTests.py b/AugmentTests.py index c6f25ffbf240f31f5c7a04fc439d4528b08ee79c..b8dd7787d74b659d05913c4b7e3fe2e7dfc37e7b 100644 --- a/AugmentTests.py +++ b/AugmentTests.py @@ -15,19 +15,34 @@ if __name__ == "__main__": # records.to_csv("output/perfs.csv",index = False) - #Et enfin le rt-shift - mean = 0 + # #Ensuite le rt-shift + # mean = 0 + # nTests = 5 + # records = pd.read_csv("output/perfs.csv") + # for prob in [k/5 for k in range(1,6)]: + # args.augment_args[2] = prob + # args.augment_args[5] = mean + # for std in [k/2 for k in range(5,25,5)]: + # args.augment_args[6] = std + # losses = np.zeros(nTests); accs = np.zeros(nTests) + # for random in range(nTests): + # args.random_state = random + # losses[random], accs[random] = run_duo(args) + # records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) + # records.to_csv("output/perfs.csv",index = False) + + #Et enfin le int_shift + nTests = 5 records = pd.read_csv("output/perfs.csv") for prob in [k/5 for k in range(1,6)]: - args.augment_args[2] = prob - args.augment_args[5] = mean - for std in [k/2 for k in range(5,25,5)]: - args.augment_args[6] = std - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): + args.augment_args[1] = prob + for borne_max in [1.25,1.5,1.75,2.,2.25,2.5]: + args.augment_args[4] = borne_max + losses = np.zeros(nTests); accs = np.zeros(nTests) + for random in range(nTests): args.random_state = random losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) + records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{borne_max}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) records.to_csv("output/perfs.csv",index = False) \ No newline at end of file diff --git a/config/config.py b/config/config.py index 63b5999c22eef360fb137a88c2320ce3282c589d..ae2861d93bbeea900fa59dfbbfb0381c937af0a2 100644 --- a/config/config.py +++ b/config/config.py @@ -5,7 +5,7 @@ def load_args(): parser = argparse.ArgumentParser() parser.add_argument('--epoches', type=int, default=20) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0.5,0.1,0.1,0.,7.5]) + parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5]) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam") diff --git a/dataset/dataset.py b/dataset/dataset.py index 7d881e266ef6c01089b33d8f2227947e1b35f438..1f1fcd664b991698179882306ab8a0d3b04c3614 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -220,8 +220,8 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir, batch_size, args, shuffle=True): train_transform = transforms.Compose( [#Random_erasing(args.augment_args[0], args.augment_args[3]), - #Random_int_noise(args.augment_args[1], args.augment_args[4]), - Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), + Random_int_noise(args.augment_args[1], args.augment_args[4]), + #Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), transforms.Resize((224, 224)), Threshold_noise(args.noise_threshold), Log_normalisation(), diff --git a/main.py b/main.py index 2cd28d3050f8ab7ae463989b269f1b2bfc129394..7090c536d730452c7789e9bd9e338863838114bc 100644 --- a/main.py +++ b/main.py @@ -249,15 +249,17 @@ def run_duo(args): best_loss = loss best_acc = acc # plot and save training figs + plt.figure(figsize=(14, 9)) plt.plot(train_loss) plt.plot(val_loss) plt.plot(train_loss) plt.plot(train_loss) - - plt.savefig(f'output/Atraining_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') + plt.savefig(f'output/dataAugments/intShift prob {args.augment_args[1]} borne_max {args.augment_args[4]}.png') + #plt.savefig(f'output/Atraining_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') #load and evaluate best model load_model(model, args.save_path) - make_prediction_duo(model,data_test, f'output/Amodel_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') + make_prediction_duo(model,data_test, f'output/dataAugments/intShift prob {args.augment_args[1]} borne_max {args.augment_args[4]}.png') + #make_prediction_duo(model,data_test, f'output/Amodel_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') return best_loss,best_acc def make_prediction_duo(model, data, f_name): @@ -312,7 +314,7 @@ def make_prediction_duo(model, data, f_name): confiance_df = pd.DataFrame(confiance, index=[i for i in classes], columns=[i for i in classes]) plt.figure(figsize=(14, 9)) - sn.heatmap(confiance_df, annot=confiance.astype("<U4")+np.full(confiance.shape,"\u00B1")+variance.astype("<U4"),fmt='') + sn.heatmap(confiance_df, annot=confiance.astype("<U4")+np.full(confiance.shape,"\u00B1")+variance.astype("<U3"),fmt='') confiName = f_name.split("/") confiName[-1] = "confiance_matrix_"+confiName[-1] plt.savefig('/'.join(confiName)) diff --git a/output/best_model.pt b/output/best_model.pt index 9b11be368a33bf85823fdf9fd56de298944968a6..e5bd5ca6efe6ba8bbc977f68e482226458907780 100644 Binary files a/output/best_model.pt and b/output/best_model.pt differ