From 8f15534fbfefe4add5cfcd448b3c742f5a012e40 Mon Sep 17 00:00:00 2001
From: lcalmettes <leo.calmettes@etu.ec-lyon.fr>
Date: Mon, 12 May 2025 17:03:28 +0200
Subject: [PATCH] =?UTF-8?q?=09modifi=C3=A9=C2=A0:=20=20=20=20=20=20=20=20?=
 =?UTF-8?q?=20AugmentTests.py=20=09modifi=C3=A9=C2=A0:=20=20=20=20=20=20?=
 =?UTF-8?q?=20=20=20config/config.py?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 AugmentTests.py  | 53 ++++++++++++++----------------------------------
 config/config.py |  2 +-
 2 files changed, 16 insertions(+), 39 deletions(-)

diff --git a/AugmentTests.py b/AugmentTests.py
index b71b81e..deecc9a 100644
--- a/AugmentTests.py
+++ b/AugmentTests.py
@@ -12,44 +12,21 @@ if __name__ == "__main__":
         args.random_state = random
         losses[random], accs[random] = run_duo(args)
     records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
-    records.to_csv("output/DataAugment/perfs.csv",index = False)
+    records.to_csv("output/perfs.csv",index = False)
         
-    #On continue avec le random_erasing
-    for prob in [k/20 for k in range(1,21)]:
-        args.augment_args[0] = prob
-        for prop in [k/20 for k in range(1,21)]:
-            args.augment_args[3] = prop
-            losses = np.zeros(5); accs = np.zeros(5)
-            for random in range(5):
-                args.random_state = random
-                losses[random], accs[random] = run_duo(args)
-            records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
-            records.to_csv("output/DataAugment/perfs.csv",index = False)
-            
-    #Puis le int shift
-    for prob in [k/20 for k in range(1,21)]:
-        args.augment_args[1] = prob
-        for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]):
-            args.augment_args[4] = maximum
-            losses = np.zeros(5); accs = np.zeros(5)
-            for random in range(5):
-                args.random_state = random
-                losses[random], accs[random] = run_duo(args)
-            records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
-            records.to_csv("output/DataAugment/perfs.csv",index = False)
-            
-    #Et enfin le rt-shift
-    for prob in [k/20 for k in range(1,21)]:
-        args.augment_args[2] = prob
-        for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
-            args.augment_args[5] = mean
-            for std in [mean/k for k in range(1,11)]:
-                args.augment_args[6] = std
-                losses = np.zeros(5); accs = np.zeros(5)
-                for random in range(5):
-                    args.random_state = random
-                    losses[random], accs[random] = run_duo(args)
-                records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
-                records.to_csv("output/DataAugment/perfs.csv",index = False)
+
+    # #Et enfin le rt-shift
+    # for prob in [k/20 for k in range(1,21)]:
+    #     args.augment_args[2] = prob
+    #     for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
+    #         args.augment_args[5] = mean
+    #         for std in [mean/k for k in range(1,11)]:
+    #             args.augment_args[6] = std
+    #             losses = np.zeros(5); accs = np.zeros(5)
+    #             for random in range(5):
+    #                 args.random_state = random
+    #                 losses[random], accs[random] = run_duo(args)
+    #             records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
+    #             records.to_csv("output/DataAugment/perfs.csv",index = False)
             
     
\ No newline at end of file
diff --git a/config/config.py b/config/config.py
index 63b5999..ae2861d 100644
--- a/config/config.py
+++ b/config/config.py
@@ -5,7 +5,7 @@ def load_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--epoches', type=int, default=20)
     parser.add_argument('--eval_inter', type=int, default=1)
-    parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0.5,0.1,0.1,0.,7.5])
+    parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5])
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--optim', type = str, default = "Adam")
-- 
GitLab