From a182465c519a44b16f94018066d65d1b46fe34a3 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 25 Mar 2025 10:29:20 +0100 Subject: [PATCH] stratified dataset split --- dataset/dataset.py | 17 +++++++++++------ main.py | 3 ++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/dataset/dataset.py b/dataset/dataset.py index 1512488..1365db8 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -1,3 +1,5 @@ +import random + import numpy as np import torch import torchvision @@ -9,7 +11,7 @@ import os.path from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from pathlib import Path from collections import OrderedDict - +from sklearn.model_selection import train_test_split IMG_EXTENSIONS = ".npy" class Threshold_noise: @@ -54,11 +56,13 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): print('Default val transform') train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) - generator1 = torch.Generator().manual_seed(42) - indices = torch.randperm(len(train_dataset),generator=generator1) - val_size = len(train_dataset) // 5 - train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size]) - val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:]) + + #Same seed to avoid overlap while having different transforms + seed = random.randint(0,1000) + train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True, + stratify=True) + _, val_dataset = train_test_split(val_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True, + stratify=True) data_loader_train = data.DataLoader( dataset=train_dataset, @@ -198,6 +202,7 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0): Log_normalisation(), transforms.Normalize(0.5, 0.5)]) print('Default val transform') + train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform) val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform) generator1 = torch.Generator().manual_seed(42) diff --git a/main.py b/main.py index 30f7843..4be686c 100644 --- a/main.py +++ b/main.py @@ -257,6 +257,7 @@ def load_model(model, path): if __name__ == '__main__': args = load_args() if args.model_type=='duo': - run_duo(args) + # run_duo(args) + data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size) else : run(args) \ No newline at end of file -- GitLab