diff --git a/AugmentTests.py b/AugmentTests.py
index b8dd7787d74b659d05913c4b7e3fe2e7dfc37e7b..cf8d72850b249e6261b7da0e538c498ccc6e10d7 100644
--- a/AugmentTests.py
+++ b/AugmentTests.py
@@ -7,12 +7,12 @@ if __name__ == "__main__":
     args = load_args()
     
     # Start with the standard (no augmentation) baseline
-    # losses = np.zeros(20); accs = np.zeros(20)
-    # for random in range(20):
-    #     args.random_state = random
-    #     losses[random], accs[random] = run_duo(args)
-    # records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
-    # records.to_csv("output/perfs.csv",index = False)
+    losses = np.zeros(20); accs = np.zeros(20)
+    for seed in range(20):
+        args.random_state = seed
+        losses[seed], accs[seed] = run_duo(args)
+    records = pd.DataFrame([["Standard",losses.mean(),losses.std(),accs.mean(),accs.std()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])
+    records.to_csv("output/perfs.csv",index = False)
         
 
     # # Then the rt-shift
@@ -31,18 +31,18 @@ if __name__ == "__main__":
     #         records = pd.concat([records,pd.DataFrame([[f"rtShift prob{prob} mean{mean} std{std}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
     #         records.to_csv("output/perfs.csv",index = False)
         
-    # And finally the int_shift
-    nTests = 5
-    records = pd.read_csv("output/perfs.csv")
-    for prob in [k/5 for k in range(1,6)]:
-        args.augment_args[1] = prob
-        for borne_max in [1.25,1.5,1.75,2.,2.25,2.5]:
-            args.augment_args[4] = borne_max
-            losses = np.zeros(nTests); accs = np.zeros(nTests)
-            for random in range(nTests):
-                args.random_state = random
-                losses[random], accs[random] = run_duo(args)
-            records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{borne_max}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
-            records.to_csv("output/perfs.csv",index = False)
+    # # And finally the int_shift
+    # nTests = 5
+    # records = pd.read_csv("output/perfs.csv")
+    # for prob in [k/5 for k in range(1,6)]:
+    #     args.augment_args[1] = prob
+    #     for borne_max in [1.25,1.5,1.75,2.,2.25,2.5]:
+    #         args.augment_args[4] = borne_max
+    #         losses = np.zeros(nTests); accs = np.zeros(nTests)
+    #         for random in range(nTests):
+    #             args.random_state = random
+    #             losses[random], accs[random] = run_duo(args)
+    #         records = pd.concat([records,pd.DataFrame([[f"intShift prob{prob} max{borne_max}",losses.mean().item(),losses.std().item(),accs.mean().item(),accs.std().item()]],columns = ["Augmentation","mu_loss","std_loss","mu_acc","std_acc"])])
+    #         records.to_csv("output/perfs.csv",index = False)
         
     
\ No newline at end of file
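All three benchmark blocks in AugmentTests.py repeat the same pattern: loop over seeds, collect loss and accuracy, append one summary row, rewrite output/perfs.csv. A small helper could factor this out; the sketch below is hypothetical and not part of the repository:

```python
import numpy as np
import pandas as pd

def benchmark(label, args, run, n_tests=20):
    """Run `run(args)` across n_tests random seeds and return one summary row."""
    losses = np.zeros(n_tests)
    accs = np.zeros(n_tests)
    for seed in range(n_tests):
        args.random_state = seed
        losses[seed], accs[seed] = run(args)
    return pd.DataFrame(
        [[label, losses.mean(), losses.std(), accs.mean(), accs.std()]],
        columns=["Augmentation", "mu_loss", "std_loss", "mu_acc", "std_acc"],
    )
```

Each block would then reduce to `records = pd.concat([records, benchmark(f"intShift prob{prob} max{borne_max}", args, run_duo, nTests)])` followed by the usual `records.to_csv("output/perfs.csv", index=False)`.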
diff --git a/config/config.py b/config/config.py
index ae2861d93bbeea900fa59dfbbfb0381c937af0a2..8f8a0a0d85dfbe799c478c7705c7baeeeaede269 100644
--- a/config/config.py
+++ b/config/config.py
@@ -5,7 +5,7 @@ def load_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--epoches', type=int, default=20)
     parser.add_argument('--eval_inter', type=int, default=1)
-    parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.,7.5])
+    parser.add_argument('--augment_args', nargs = '+', type = float, default = [1,0,0,0.99,0.1,0.,7.5])
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--optim', type = str, default = "Adam")
@@ -16,11 +16,11 @@ def load_args():
     parser.add_argument('--batch_size', type=int, default=128)
     parser.add_argument('--model', type=str, default='ResNet18')
     parser.add_argument('--model_type', type=str, default='duo')
-    parser.add_argument('--dataset_dir', type=str, default='data/fused_data/species_training')
+    parser.add_argument('--dataset_dir', type=str, default='data/fused_data/clean_species')
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
     parser.add_argument('--random_state',type = int, default = 42)
     args = parser.parse_args()
 
-    return args
+    return args
\ No newline at end of file
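The seven augment_args slots are consumed positionally in load_data_duo (dataset/dataset.py). The mapping below is inferred from that call site rather than documented anywhere, so treat it as an assumption:

```python
from config.config import load_args

args = load_args()
# Inferred positional meaning (assumption based on load_data_duo's indexing):
#   [0] erasing prob      [3] erased peak proportion
#   [1] int-noise prob    [4] int-noise maximum
#   [2] rt-shift prob     [5] rt-shift mean   [6] rt-shift std
args.augment_args[0] = 1.0   # always attempt random erasing
args.augment_args[3] = 0.99  # matches the new default [1, 0, 0, 0.99, ...]
```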
diff --git a/dataset/dataset.py b/dataset/dataset.py
index 1f1fcd664b991698179882306ab8a0d3b04c3614..b17eddeb64b8f5ff82394bac06df00c0cd4c5410 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -10,6 +10,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
 from sklearn.model_selection import train_test_split
+from skimage import measure
 IMG_EXTENSIONS = ".npy"
 
 class Threshold_noise:
@@ -32,16 +33,56 @@ class Log_normalisation:
 
 class Random_erasing:
     """with a probability prob, erases a proportion prop of the image"""
-    def __init__(self, prob, prop):
+    def __init__(self, prob, prop, threshold, hK = 15, wK = 3):
         self.prob = prob
-        self.prop = prop
-    
+        self.hK = hK
+        self.wK = wK
+        self.prop = prop #proportion of peaks removed
+        self.thresh = threshold #peak threshold, not the noise threshold; maybe the two could be unified?
+        self.kernelP = torch.nn.functional.pad(torch.ones(1,1,hK,wK),(1,1,1,1)).int() #padding grows the considered hK x wK window by +1 on each side
+        self.kernelP[0,0,1:-1,wK//2+1] = 2 #double weight on the central m/z column
+        self.kernelG = torch.cat(
+            [self.kernelP,
+             transforms.functional.affine(self.kernelP,0,[0,1],1,[0,0]), #shifted down (rt)
+             transforms.functional.affine(self.kernelP,0,[0,-1],1,[0,0]),#shifted up (rt)
+             transforms.functional.affine(self.kernelP,0,[1,0],1,[0,0]), #shifted right (m/z)
+             transforms.functional.affine(self.kernelP,0,[-1,0],1,[0,0]),#shifted left (m/z)
+            ])
+        self.floutage = torch.ones(1,1,hK+2,wK+2)-torch.nn.functional.pad(torch.ones(1,1,hK,wK),(1,1,1,1)).double() #one-pixel ring around the window, normalised below into an averaging kernel
+        self.floutage /= torch.sum(self.floutage)
     def __call__(self,x):
-        if np.random.rand() > self.prob:
-            return x*(torch.rand_like(x) > self.prop)
+        if np.random.rand() < self.prob:
+            thresh_reached = (x>self.thresh).int()
+            eval_centrage = torch.nn.functional.conv2d(thresh_reached.unsqueeze(0),self.kernelG,padding=(self.hK//2+1,self.wK//2+1)).squeeze()
+            centres = ((eval_centrage[0]>eval_centrage[1])*(eval_centrage[0]>=eval_centrage[2])*(eval_centrage[0]>eval_centrage[3])*(eval_centrage[0]>=eval_centrage[4])) #centre response beats all four shifted responses: a local peak
+            pics_suppr = (torch.rand_like(x) < self.prop)*centres #subset of peaks drawn for removal
+            flou = torch.nn.functional.conv2d(x.unsqueeze(0),self.floutage,padding = (self.hK//2+1,self.wK//2+1)) #local background estimate around each pixel
+            zones_suppr = torch.nn.functional.conv2d(pics_suppr*flou,torch.ones(1,1,self.hK,self.wK).double(),padding = (self.hK//2,self.wK//2)).squeeze() #spread the selected centres over the hK x wK window
+            return torch.where((zones_suppr==0),x,0) #zero the marked peak zones, keep everything else
         return x
         
+class Random_erasing2:
+    """With probability prob, erases each connected peak region of the image independently with probability prop"""
+    def __init__(self, prob, prop):
+        self.prob = prob
+        self.prop = prop
+
+    def __call__(self,x):
+        if np.random.rand() < self.prob:
+            labels = measure.label(x.numpy()>0,connectivity=1) #4-connected components of the non-zero mask
+            regions = measure.regionprops(labels)
+            pics_suppr = np.random.rand(len(regions))<self.prop #regions drawn for removal
+            for k in range(len(regions)):
+                if pics_suppr[k]:
+                    y1,x1,y2,x2 = regions[k].bbox
+                    x[y1:y2,x1:x2] *= torch.from_numpy(~regions[k].image) #zero only the region's pixels, not its whole bbox
+        return x
+
 class Random_int_noise:
     """With a probability prob, adds a gaussian noise to the image """
     def __init__(self, prob, maximum):
@@ -50,7 +91,7 @@ class Random_int_noise:
         self.delta = maximum-self.minimum
         
     def __call__(self, x):
-        if np.random.rand() > self.prob:
+        if np.random.rand() < self.prob:
             return x*(self.minimum + torch.rand_like(x)*self.delta)
         return x
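Random_int_noise is multiplicative rather than additive: each pixel is scaled by an independent uniform factor between self.minimum (set in a part of __init__ not shown in this diff) and maximum. A self-contained sketch of the same idea, with an assumed range:

```python
import torch

def random_int_noise(x, prob=0.5, minimum=0.8, maximum=1.25):
    """With probability prob, rescale each pixel by a factor drawn
    uniformly from [minimum, maximum]; otherwise return x unchanged."""
    if torch.rand(()) < prob:
        return x * (minimum + torch.rand_like(x) * (maximum - minimum))
    return x

y = random_int_noise(torch.rand(1, 224, 224))
```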
     
@@ -62,7 +103,7 @@ class Random_shift_rt:
         self.std = float(std)
         
     def __call__(self,x):
-        if np.random.rand()>self.prob:
+        if np.random.rand()<self.prob:
             shift = torch.normal(self.mean,self.std)
             return transforms.functional.affine(x,0,[0,shift],1,[0,0])
         return x
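Random_shift_rt draws a normally distributed pixel offset and translates the image along the rt (vertical) axis; torchvision's translate argument is [horizontal, vertical], so the second slot is the rt direction. A standalone sketch of the same operation:

```python
import torch
from torchvision import transforms

x = torch.rand(1, 224, 224)
shift = int(torch.normal(torch.tensor(0.0), torch.tensor(7.5)).item())  # pixels along rt
y = transforms.functional.affine(x, angle=0, translate=[0, shift],
                                 scale=1.0, shear=[0.0, 0.0])
```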
@@ -219,8 +260,8 @@ class ImageFolderDuo(data.Dataset):
 
 def load_data_duo(base_dir, batch_size, args, shuffle=True):
     train_transform = transforms.Compose(
-        [#Random_erasing(args.augment_args[0], args.augment_args[3]),
-         Random_int_noise(args.augment_args[1], args.augment_args[4]),
+        [#Random_erasing2(args.augment_args[0], args.augment_args[3]),
+         #Random_int_noise(args.augment_args[1], args.augment_args[4]),
          #Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]),
          transforms.Resize((224, 224)),
          Threshold_noise(args.noise_threshold),
@@ -252,7 +293,7 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True):
         dataset=train_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
-        num_workers=8,
+        num_workers=16,
         collate_fn=None,
         pin_memory=False,
     )
@@ -260,11 +301,10 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True):
         dataset=val_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
-        num_workers=8,
+        num_workers=16,
         collate_fn=None,
         pin_memory=False,
     )
-
     return data_loader_train, data_loader_test
 
 
diff --git a/main.py b/main.py
index 7090c536d730452c7789e9bd9e338863838114bc..d01f86cc9a25929f8cc1f7bedc89e3863e34e33a 100644
--- a/main.py
+++ b/main.py
@@ -236,6 +236,12 @@ def run_duo(args):
     else:
         raise Exception("Unusual args.optim")
     #train model
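+    # evaluate the untrained model once (logged as epoch -1) so the learning curves include the starting point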
+    loss, acc = test_duo(model,data_train,loss_function,-1)
+    train_loss.append(loss)
+    train_acc.append(acc)
+    loss, acc = test_duo(model,data_test,loss_function,-1)
+    val_loss.append(loss)
+    val_acc.append(acc)
     for e in range(args.epoches):
         loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
         train_loss.append(loss)
@@ -254,11 +260,11 @@ def run_duo(args):
     plt.plot(val_loss)
     plt.plot(train_loss)
-    plt.plot(train_loss)
-    plt.savefig(f'output/dataAugments/intShift prob {args.augment_args[1]} borne_max {args.augment_args[4]}.png')
+    plt.savefig('output/species_clean_best_param.png')
     #plt.savefig(f'output/Atraining_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png')
     #load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test, f'output/dataAugments/intShift prob {args.augment_args[1]} borne_max {args.augment_args[4]}.png')
+    make_prediction_duo(model,data_test, 'output/species_clean_best_param_predictions.png') #distinct file name so the training plot saved above is not overwritten
     #make_prediction_duo(model,data_test, f'output/Amodel_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png')
     return best_loss,best_acc
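The plotting block above saves unlabeled curves; a labeled variant over the same lists (a sketch, not the repo's code) would make the saved PNG self-describing:

```python
import matplotlib.pyplot as plt

plt.figure()
plt.plot(val_loss, label="validation loss")
plt.plot(train_loss, label="training loss")
plt.xlabel("epoch (first point is the untrained baseline)")
plt.ylabel("loss")
plt.legend()
plt.savefig('output/species_clean_best_param.png')
```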
 
diff --git a/output/best_model.pt b/output/best_model.pt
index c8679fd4bcb6b10ac9c7131d7fe1e76e6f1cf1d9..7f213f61ad3341d6cef22b361900807adc5700f6 100644
Binary files a/output/best_model.pt and b/output/best_model.pt differ