diff --git a/config/config.py b/config/config.py
index 9673a42fe6e7bf45d60d16bfa47b378f35db4a5e..dfb2e26c35e2217d461b76de47c91a6d43e50fc6 100644
--- a/config/config.py
+++ b/config/config.py
@@ -15,6 +15,7 @@ def load_args():
     parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data')
     parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data')
     parser.add_argument('--dataset_test_dir', type=str, default=None)
+    parser.add_argument('--outname', type=str, default=None)
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
diff --git a/dataset/dataset.py b/dataset/dataset.py
index 32e96f4396eccfe938b5fc561bbab19e93ba6747..6130f5db9f96cf62bbfe45b95ae6b07bb482828c 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -10,9 +10,64 @@ import os.path
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
+from skimage import measure
 from sklearn.model_selection import train_test_split
 IMG_EXTENSIONS = ".npy"
 
+
+class Random_erasing2:
+    """with a probability prob, erases a proportion prop of the image"""
+
+    def __init__(self, prob, prop):
+        self.prob = prob
+        self.prop = prop
+
+    def __call__(self, x):
+        if np.random.rand() < self.prob:
+            labels = measure.label(x.numpy() > 0, connectivity=1)
+            regions = measure.regionprops(labels)
+            pics_suppr = np.random.rand(len(regions)) < self.prop
+            for k in range(len(regions)):
+                if pics_suppr[k]:
+                    try:
+                        _, y1, x1, _, y2, x2 = regions[k].bbox
+                    except:
+                        raise Exception(regions[k].bbox)
+                    x[:, y1:y2, x1:x2] *= regions[k].image == False
+            return x
+
+        return x
+
+
+class Random_int_noise:
+    """With a probability prob, adds a gaussian noise to the image """
+
+    def __init__(self, prob, maximum):
+        self.prob = prob
+        self.minimum = 1 / maximum
+        self.delta = maximum - self.minimum
+
+    def __call__(self, x):
+        if np.random.rand() < self.prob:
+            return x * (self.minimum + torch.rand_like(x) * self.delta)
+        return x
+
+
+class Random_shift_rt:
+    """With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
+
+    def __init__(self, prob, mean, std):
+        self.prob = prob
+        self.mean = torch.tensor(float(mean))
+        self.std = float(std)
+
+    def __call__(self, x):
+        if np.random.rand() < self.prob:
+            shift = torch.normal(self.mean, self.std)
+            return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
+        return x
+
+
 class Threshold_noise:
     """Remove intensities under given threshold"""
 
@@ -31,9 +86,6 @@ class Log_normalisation:
     def __call__(self, x):
         return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
 
-class Random_shift_rt:
-    pass
-
 class Repeat:
     "Repeat tensor along new dim"
 
diff --git a/main.py b/main.py
index b4832c5f05f3afd29b5023e9fc8722ae23a35e9b..c6bfe16124b813475e01c2a7493fcef4e8ca0622 100644
--- a/main.py
+++ b/main.py
@@ -97,11 +97,11 @@ def run(args):
     plt.plot(train_acc)
     plt.ylim(0, 1.05)
     plt.show()
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+    plt.savefig('output/training_plot_{}.png'.format(args.outname))
 
     #load and evaluated best model
     load_model(model, args.save_path)
-    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+    make_prediction(model,data_test, 'output/confusion_matrix_{}.png'.format(args.outname))
 
 
 def make_prediction(model, data, f_name):