From a182465c519a44b16f94018066d65d1b46fe34a3 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 25 Mar 2025 10:29:20 +0100
Subject: [PATCH] stratified dataset split

---
 dataset/dataset.py | 17 +++++++++++------
 main.py            |  3 ++-
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 1512488..1365db8 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,3 +1,5 @@
+import random
+
 import numpy as np
 import torch
 import torchvision
@@ -9,7 +11,7 @@ import os.path
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
-
+from sklearn.model_selection import train_test_split
 IMG_EXTENSIONS = ".npy"
 
 class Threshold_noise:
@@ -54,11 +56,13 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     print('Default val transform')
     train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
     val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
-    generator1 = torch.Generator().manual_seed(42)
-    indices = torch.randperm(len(train_dataset),generator=generator1)
-    val_size = len(train_dataset) // 5
-    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
-    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    # Use the same seed for both splits so the train and val subsets do not overlap while keeping different transforms
+    seed = random.randint(0,1000)
+    train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
+                                             stratify=True)
+    _, val_dataset = train_test_split(val_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
+                                             stratify=True)
 
     data_loader_train = data.DataLoader(
         dataset=train_dataset,
@@ -198,6 +202,7 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
          Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
+
     train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
     val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)
diff --git a/main.py b/main.py
index 30f7843..4be686c 100644
--- a/main.py
+++ b/main.py
@@ -257,6 +257,7 @@ def load_model(model, path):
 if __name__ == '__main__':
     args = load_args()
     if args.model_type=='duo':
-        run_duo(args)
+        # run_duo(args)
+        data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
     else :
         run(args)
\ No newline at end of file
-- 
GitLab