diff --git a/dataset/dataset.py b/dataset/dataset.py
index 435b2044c07fb404cde98fdfe256b3b4f729d66a..ec59fead2ec74552a18246e363b644762b764d7a 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -28,18 +28,33 @@ class Random_shift_rt():
 
 def load_data(base_dir, batch_size, shuffle=True, transform=None):
     if transform is None :
-        transform = transforms.Compose(
+        train_transform = transforms.Compose(
             [transforms.Grayscale(num_output_channels=1),
              transforms.ToTensor(),
              transforms.Resize((224,224)),
+             Threshold_noise(500),
              Log_normalisation(),
              transforms.Normalize(0.5, 0.5)])
-        print('default transform')
-    dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=transform)
+        print('Default train transform')
+
+        val_transform = transforms.Compose(
+            [transforms.Grayscale(num_output_channels=1),
+             transforms.ToTensor(),
+             transforms.Resize((224,224)),
+             Threshold_noise(500),
+             Log_normalisation(),
+             transforms.Normalize(0.5, 0.5)])
+        print('Default val transform')
+    train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
+    val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)
-    data_train, data_test = random_split(dataset, [0.8, 0.2], generator=generator1)
+    indices = torch.randperm(len(train_dataset),generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
     data_loader_train = DataLoader(
-        dataset=data_train,
+        dataset=train_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
         num_workers=0,
@@ -48,7 +63,7 @@ def load_data(base_dir, batch_size, shuffle=True, transform=None):
     )
 
     data_loader_test = DataLoader(
-        dataset=data_test,
+        dataset=val_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
         num_workers=0,
diff --git a/output.png b/output.png
deleted file mode 100644
index c06de6a4542828356cb331664b0b1017ef83949e..0000000000000000000000000000000000000000
Binary files a/output.png and /dev/null differ