diff --git a/config/config.py b/config/config.py
index bad9ecbb069ec9c572c7594f4be0f9a3aec353b1..d898b6bbef2f8ea9c6a9e3bebd3380057cba4ca1 100644
--- a/config/config.py
+++ b/config/config.py
@@ -7,6 +7,7 @@ def load_args():
     parser.add_argument('--epoches', type=int, default=100)
     parser.add_argument('--save_inter', type=int, default=50)
     parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
diff --git a/dataset/dataset.py b/dataset/dataset.py
index ec59fead2ec74552a18246e363b644762b764d7a..6f807d03000d5328897ff4d009413ab6799a6a0a 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -26,25 +26,24 @@ class Random_shift_rt():
     pass
 
 
-def load_data(base_dir, batch_size, shuffle=True, transform=None):
-    if transform is None :
-        train_transform = transforms.Compose(
-            [transforms.Grayscale(num_output_channels=1),
-             transforms.ToTensor(),
-             transforms.Resize((224,224)),
-             Threshold_noise(500),
-             Log_normalisation(),
-             transforms.Normalize(0.5, 0.5)])
-        print('Default train transform')
+def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
+    train_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224,224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
 
-        val_transform = transforms.Compose(
-            [transforms.Grayscale(num_output_channels=1),
-             transforms.ToTensor(),
-             transforms.Resize((224,224)),
-             Threshold_noise(500),
-             Log_normalisation(),
-             transforms.Normalize(0.5, 0.5)])
-        print('Default val transform')
+    val_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224,224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
     train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
     val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)