From 8f15534fbfefe4add5cfcd448b3c742f5a012e40 Mon Sep 17 00:00:00 2001 From: lcalmettes <leo.calmettes@etu.ec-lyon.fr> Date: Mon, 12 May 2025 17:03:28 +0200 Subject: [PATCH] =?UTF-8?q?=09modifi=C3=A9=C2=A0:=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20AugmentTests.py=20=09modifi=C3=A9=C2=A0:=20=20=20=20=20=20?= =?UTF-8?q?=20=20=20config/config.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AugmentTests.py | 53 ++++++++++++++---------------------------------- config/config.py | 2 +- 2 files changed, 16 insertions(+), 39 deletions(-) diff --git a/AugmentTests.py b/AugmentTests.py index b71b81e..deecc9a 100644 --- a/AugmentTests.py +++ b/AugmentTests.py @@ -12,44 +12,21 @@ if __name__ == "__main__": args.random_state = random losses[random], accs[random] = run_duo(args) records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"]) - records.to_csv("output/DataAugment/perfs.csv",index = False) + records.to_csv("output/perfs.csv",index = False) - #On continue avec le random_erasing - for prob in [k/20 for k in range(1,21)]: - args.augment_args[0] = prob - for prop in [k/20 for k in range(1,21)]: - args.augment_args[3] = prop - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) - - #Puis le int shift - for prob in [k/20 for k in range(1,21)]: - args.augment_args[1] = prob - for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]): - args.augment_args[4] = maximum - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) - - #Et enfin le rt-shift - for prob in [k/20 for k in range(1,21)]: - args.augment_args[2] = prob - for mean in [5,10,15,20,25,30,40,50,60,70,80,90]: - args.augment_args[5] = mean - for std in [mean/k for k in range(1,11)]: - args.augment_args[6] = std - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) + + # #Et enfin le rt-shift + # for prob in [k/20 for k in range(1,21)]: + # args.augment_args[2] = prob + # for mean in [5,10,15,20,25,30,40,50,60,70,80,90]: + # args.augment_args[5] = mean + # for std in [mean/k for k in range(1,11)]: + # args.augment_args[6] = std + # losses = np.zeros(5); accs = np.zeros(5) + # for random in range(5): + # args.random_state = random + # losses[random], accs[random] = run_duo(args) + # records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) + # records.to_csv("output/DataAugment/perfs.csv",index = False) \ No newline at end of file diff --git a/config/config.py b/config/config.py index 63b5999..ae2861d 100644 --- a/config/config.py +++ b/config/config.py @@ -5,7 +5,7 @@ def load_args(): parser = argparse.ArgumentParser() parser.add_argument('--epoches', type=int, default=20) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0.5,0.1,0.1,0.,7.5]) + parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5]) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam") -- GitLab