diff --git a/AugmentTests.py b/AugmentTests.py index deecc9aa138efd907051048ae898462221f573f6..e0dc7103870a57aa92613d26c704861062266617 100644 --- a/AugmentTests.py +++ b/AugmentTests.py @@ -7,26 +7,27 @@ if __name__ == "__main__": args = load_args() #On commence avec le standard - losses = np.zeros(20); accs = np.zeros(20) - for random in range(20): - args.random_state = random - losses[random], accs[random] = run_duo(args) - records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"]) - records.to_csv("output/perfs.csv",index = False) + # losses = np.zeros(20); accs = np.zeros(20) + # for random in range(20): + # args.random_state = random + # losses[random], accs[random] = run_duo(args) + # records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"]) + # records.to_csv("output/perfs.csv",index = False) - # #Et enfin le rt-shift - # for prob in [k/20 for k in range(1,21)]: - # args.augment_args[2] = prob - # for mean in [5,10,15,20,25,30,40,50,60,70,80,90]: - # args.augment_args[5] = mean - # for std in [mean/k for k in range(1,11)]: - # args.augment_args[6] = std - # losses = np.zeros(5); accs = np.zeros(5) - # for random in range(5): - # args.random_state = random - # losses[random], accs[random] = run_duo(args) - # records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) - # records.to_csv("output/DataAugment/perfs.csv",index = False) - + #Et enfin le rt-shift + records = pd.read_csv("output/perfs.csv") + for prob in [k/5 for k in range(1,6)]: + args.augment_args[2] = prob + mean = 0 + args.augment_args[5] = mean + for std in [k/2 for k in range(5,25,5)]: + args.augment_args[6] = std + losses = np.zeros(5); accs = np.zeros(5) + for random in range(5): + args.random_state = random + losses[random], accs[random] = run_duo(args) + records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) + records.to_csv("output/DataAugment/perfs.csv",index = False) + \ No newline at end of file