Skip to content
Snippets Groups Projects
Commit 9f6d850e authored by Léo Calmettes's avatar Léo Calmettes
Browse files

nouveau fichier : AugmentTests.py

parent f11926cd
No related branches found
No related tags found
No related merge requests found
from main import run_duo
from config.config import load_args
import pandas as pd
import numpy as np
if __name__ == "__main__":
args = load_args()
#On commence avec le standard
losses = np.zeros(20); accs = np.zeros(20)
for random in range(20):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
records.to_csv("output/DataAugment/perfs.csv",index = False)
#On continue avec le random_erasing
for prob in [k/20 for k in range(1,21)]:
args.augment_args[0] = prob
for prop in [k/20 for k in range(1,21)]:
args.augment_args[3] = prop
losses = np.zeros(5); accs = np.zeros(5)
for random in range(5):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
records.to_csv("output/DataAugment/perfs.csv",index = False)
#Puis le int shift
for prob in [k/20 for k in range(1,21)]:
args.augment_args[1] = prob
for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]):
args.augment_args[4] = maximum
losses = np.zeros(5); accs = np.zeros(5)
for random in range(5):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
records.to_csv("output/DataAugment/perfs.csv",index = False)
#Et enfin le rt-shift
for prob in [k/20 for k in range(1,21)]:
args.augment_args[2] = prob
for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
args.augment_args[5] = mean
for std in [mean/k for k in range(1,11)]:
args.augment_args[6] = std
losses = np.zeros(5); accs = np.zeros(5)
for random in range(5):
args.random_state = random
losses[random], accs[random] = run_duo(args)
records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
records.to_csv("output/DataAugment/perfs.csv",index = False)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment