diff --git a/AugmentTests.py b/AugmentTests.py index b71b81e8798d2109bdda31cf257504ecbe7eb53a..deecc9aa138efd907051048ae898462221f573f6 100644 --- a/AugmentTests.py +++ b/AugmentTests.py @@ -12,44 +12,21 @@ if __name__ == "__main__": args.random_state = random losses[random], accs[random] = run_duo(args) records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"]) - records.to_csv("output/DataAugment/perfs.csv",index = False) + records.to_csv("output/perfs.csv",index = False) - #On continue avec le random_erasing - for prob in [k/20 for k in range(1,21)]: - args.augment_args[0] = prob - for prop in [k/20 for k in range(1,21)]: - args.augment_args[3] = prop - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) - - #Puis le int shift - for prob in [k/20 for k in range(1,21)]: - args.augment_args[1] = prob - for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]): - args.augment_args[4] = maximum - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) - - #Et enfin le rt-shift - for prob in [k/20 for k in range(1,21)]: - args.augment_args[2] = prob - for mean in [5,10,15,20,25,30,40,50,60,70,80,90]: - args.augment_args[5] = mean - for std in [mean/k for k in range(1,11)]: - args.augment_args[6] = std - losses = np.zeros(5); accs = np.zeros(5) - for random in range(5): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - records.to_csv("output/DataAugment/perfs.csv",index = False) + + # #Et enfin le rt-shift + # for prob in [k/20 for k in range(1,21)]: + # args.augment_args[2] = prob + # for mean in [5,10,15,20,25,30,40,50,60,70,80,90]: + # args.augment_args[5] = mean + # for std in [mean/k for k in range(1,11)]: + # args.augment_args[6] = std + # losses = np.zeros(5); accs = np.zeros(5) + # for random in range(5): + # args.random_state = random + # losses[random], accs[random] = run_duo(args) + # records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) + # records.to_csv("output/DataAugment/perfs.csv",index = False) \ No newline at end of file diff --git a/config/config.py b/config/config.py index 63b5999c22eef360fb137a88c2320ce3282c589d..ae2861d93bbeea900fa59dfbbfb0381c937af0a2 100644 --- a/config/config.py +++ b/config/config.py @@ -5,7 +5,7 @@ def load_args(): parser = argparse.ArgumentParser() parser.add_argument('--epoches', type=int, default=20) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0.5,0.1,0.1,0.,7.5]) + parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5]) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam")