diff --git a/barlow_twin_like/config.py b/barlow_twin_like/config.py
index a9aeb92a95bc0c883ca972f72bfdd30d1b30eaf8..c7d078e3cf4e67f065bd778c667637d8c7ebe35a 100644
--- a/barlow_twin_like/config.py
+++ b/barlow_twin_like/config.py
@@ -9,6 +9,7 @@ def load_args_barlow():
     parser.add_argument('--eval_inter', type=int, default=1)
     parser.add_argument('--test_inter', type=int, default=10)
     parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--prob_erasing', type=float, default=0.3)
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--lambd', type=float, default=0.005)
     parser.add_argument('--opti', type=str, default='adam')
diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py
index 3b8c77720633a4537e6caf7be09746170ef0ebc2..d0f93128071f206e35897974b6a53a566c6d43b9 100644
--- a/barlow_twin_like/dataset_barlow.py
+++ b/barlow_twin_like/dataset_barlow.py
@@ -12,6 +12,7 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
 
+from torch.distributions.utils import probs_to_logits
 from torch.utils.data import WeightedRandomSampler
 
 IMG_EXTENSIONS = ".npy"
@@ -32,6 +33,18 @@ class Random_shift_rt:
         return x
 
 
+class Random_erasing_pixel:
+    def __init__(self,prob):
+        self.prob=prob
+
+    def __call__(self,x):
+        shape = x.size()
+        # Generate a random tensor of size (3, 3) with values 0 or 1
+        prob_array = torch.zeros_like(x)+self.prob
+        value_mask = torch.bernoulli(prob_array)
+
+        return x*value_mask
+
 def npy_loader(path):
     sample = torch.from_numpy(np.load(path))
     sample = sample.unsqueeze(0)
@@ -213,11 +226,12 @@ class ImageFolderDuo(data.Dataset):
     def __len__(self):
         return len(self.imlist)
 
-def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None):
+def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, prob_erasing=0, shuffle=True,ref_dir = None, sampler=None):
 
 
     train_transform = transforms.Compose(
         [Random_shift_rt(1,0,15),
+         Random_erasing_pixel(prob=prob_erasing),
         transforms.Resize((224, 224)),
         transforms.Normalize(0.5, 0.5)])
 
diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index 02357705e7b03e1d133ff61be56a4007fa5cba64..7d9db6a62de8894938d5c5b7158e30584920bb56 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -254,6 +254,7 @@ def run():
                       base_dir_val=args.dataset_val_dir,
                       base_dir_test=args.dataset_test_dir,
                       batch_size=args.batch_size,
+                      prob_erasing=args.prob_erasing,
                       ref_dir=args.dataset_ref_dir,
                       sampler=args.sampler))
 
diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index c48992f6b1e3b0de54b312ca3d38aa2b55915f52..68b5da14a288d30ebc590ee326b290e80b028c83 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -52,6 +52,19 @@ class Intensity_shift:
         return  torch.mul(x,intensity_ratio_map)
 
 
+class Random_shift_rt:
+    """With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
+
+    def __init__(self, prob, mean, std):
+        self.prob = prob
+        self.mean = torch.tensor(float(mean))
+        self.std = float(std)
+
+    def __call__(self, x):
+        if np.random.rand() < self.prob:
+            shift = torch.normal(self.mean, self.std)
+            return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
+        return x
 
 def default_loader(path):
     return Image.open(path).convert('RGB')
@@ -189,8 +202,7 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff
 
     ref_transform = transforms.Compose(
         [transforms.Resize((224, 224)),
-         Threshold_noise(0),
-         Log_normalisation(),
+         Random_shift_rt(1,0,10),
          transforms.Normalize(0.5, 0.5)
          ])