Skip to content
Snippets Groups Projects
Commit 8f15534f authored by Léo Calmettes's avatar Léo Calmettes
Browse files

modifié : AugmentTests.py

	modifié :         config/config.py
parent 5dd0a311
No related branches found
No related tags found
No related merge requests found
...@@ -12,44 +12,21 @@ if __name__ == "__main__": ...@@ -12,44 +12,21 @@ if __name__ == "__main__":
args.random_state = random args.random_state = random
losses[random], accs[random] = run_duo(args) losses[random], accs[random] = run_duo(args)
records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"]) records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
records.to_csv("output/DataAugment/perfs.csv",index = False) records.to_csv("output/perfs.csv",index = False)
#On continue avec le random_erasing
for prob in [k/20 for k in range(1,21)]: # #Et enfin le rt-shift
args.augment_args[0] = prob # for prob in [k/20 for k in range(1,21)]:
for prop in [k/20 for k in range(1,21)]: # args.augment_args[2] = prob
args.augment_args[3] = prop # for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
losses = np.zeros(5); accs = np.zeros(5) # args.augment_args[5] = mean
for random in range(5): # for std in [mean/k for k in range(1,11)]:
args.random_state = random # args.augment_args[6] = std
losses[random], accs[random] = run_duo(args) # losses = np.zeros(5); accs = np.zeros(5)
records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])]) # for random in range(5):
records.to_csv("output/DataAugment/perfs.csv",index = False) # args.random_state = random
# losses[random], accs[random] = run_duo(args)
#Puis le int shift # records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
for prob in [k/20 for k in range(1,21)]: # records.to_csv("output/DataAugment/perfs.csv",index = False)
args.augment_args[1] = prob
for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]):
args.augment_args[4] = maximum
losses = np.zeros(5); accs = np.zeros(5)
for random in range(5):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
records.to_csv("output/DataAugment/perfs.csv",index = False)
#Et enfin le rt-shift
for prob in [k/20 for k in range(1,21)]:
args.augment_args[2] = prob
for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
args.augment_args[5] = mean
for std in [mean/k for k in range(1,11)]:
args.augment_args[6] = std
losses = np.zeros(5); accs = np.zeros(5)
for random in range(5):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
records.to_csv("output/DataAugment/perfs.csv",index = False)
\ No newline at end of file
...@@ -5,7 +5,7 @@ def load_args(): ...@@ -5,7 +5,7 @@ def load_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epoches', type=int, default=20) parser.add_argument('--epoches', type=int, default=20)
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0.5,0.1,0.1,0.,7.5]) parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5])
parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--noise_threshold', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--optim', type = str, default = "Adam") parser.add_argument('--optim', type = str, default = "Adam")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment