From 9f6d850ed7e1149d47ecb6fbb24e9896b5ef662d Mon Sep 17 00:00:00 2001
From: lcalmettes <leo.calmettes@etu.ec-lyon.fr>
Date: Wed, 7 May 2025 10:18:50 +0200
Subject: [PATCH] =?UTF-8?q?=09nouveau=20fichier=C2=A0:=20AugmentTests.py?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 AugmentTests.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 AugmentTests.py

diff --git a/AugmentTests.py b/AugmentTests.py
new file mode 100644
index 0000000..b71b81e
--- /dev/null
+++ b/AugmentTests.py
@@ -0,0 +1,55 @@
+from main import run_duo
+from config.config import load_args
+import pandas as pd
+import numpy as np
+
+if __name__ == "__main__":
+    args = load_args()
+    
+    #On commence avec le standard
+    losses = np.zeros(20); accs = np.zeros(20)
+    for random in range(20):
+        args.random_state = random
+        losses[random], accs[random] = run_duo(args)
+    records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
+    records.to_csv("output/DataAugment/perfs.csv",index = False)
+        
+    #On continue avec le random_erasing
+    for prob in [k/20 for k in range(1,21)]:
+        args.augment_args[0] = prob
+        for prop in [k/20 for k in range(1,21)]:
+            args.augment_args[3] = prop
+            losses = np.zeros(5); accs = np.zeros(5)
+            for random in range(5):
+                args.random_state = random
+                losses[random], accs[random] = run_duo(args)
+            records = pd.concat([records,pd.DataFrame([[f"erasing prob{prob} prop{prop}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
+            records.to_csv("output/DataAugment/perfs.csv",index = False)
+            
+    #Puis le int shift
+    for prob in [k/20 for k in range(1,21)]:
+        args.augment_args[1] = prob
+        for maximum in ([k/10 for k in range(11,20)]+[k for k in range(2,11)]):
+            args.augment_args[4] = maximum
+            losses = np.zeros(5); accs = np.zeros(5)
+            for random in range(5):
+                args.random_state = random
+                losses[random], accs[random] = run_duo(args)
+            records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{maximum}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
+            records.to_csv("output/DataAugment/perfs.csv",index = False)
+            
+    #Et enfin le rt-shift
+    for prob in [k/20 for k in range(1,21)]:
+        args.augment_args[2] = prob
+        for mean in [5,10,15,20,25,30,40,50,60,70,80,90]:
+            args.augment_args[5] = mean
+            for std in [mean/k for k in range(1,11)]:
+                args.augment_args[6] = std
+                losses = np.zeros(5); accs = np.zeros(5)
+                for random in range(5):
+                    args.random_state = random
+                    losses[random], accs[random] = run_duo(args)
+                records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
+                records.to_csv("output/DataAugment/perfs.csv",index = False)
+            
+    
\ No newline at end of file
-- 
GitLab