diff --git a/dataset/dataset.py b/dataset/dataset.py
index 1512488f9aa1f3b9873029c8c2053ca3b10ce310..1365db861c8ff4e5b1fc2fdce701af878c13da59 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,3 +1,5 @@
+import random
+
 import numpy as np
 import torch
 import torchvision
@@ -9,7 +11,7 @@ import os.path
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
-
+from sklearn.model_selection import train_test_split
 IMG_EXTENSIONS = ".npy"
 
 class Threshold_noise:
@@ -54,11 +56,13 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     print('Default val transform')
     train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
     val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
-    generator1 = torch.Generator().manual_seed(42)
-    indices = torch.randperm(len(train_dataset),generator=generator1)
-    val_size = len(train_dataset) // 5
-    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
-    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    #Same seed to avoid overlap while having different transforms
+    seed = random.randint(0,1000)
+    train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
+                                             stratify=True)
+    _, val_dataset = train_test_split(val_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
+                                             stratify=True)
 
     data_loader_train = data.DataLoader(
         dataset=train_dataset,
@@ -198,6 +202,7 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
          Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
+
     train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
     val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)
diff --git a/main.py b/main.py
index 30f7843b0f6285ad5c47bf5d31e0176e2666e70e..4be686c5c7efb36825fc341afcf2c4d2ab0c4e87 100644
--- a/main.py
+++ b/main.py
@@ -257,6 +257,7 @@ def load_model(model, path):
 if __name__ == '__main__':
     args = load_args()
     if args.model_type=='duo':
-        run_duo(args)
+        # run_duo(args)
+        data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
     else :
         run(args)
\ No newline at end of file