Skip to content
Snippets Groups Projects
Commit a182465c authored by Schneider Leo's avatar Schneider Leo
Browse files

stratified dataset split

parent b7a12ac7
No related branches found
No related tags found
No related merge requests found
import random
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
...@@ -9,7 +11,7 @@ import os.path ...@@ -9,7 +11,7 @@ import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path from pathlib import Path
from collections import OrderedDict from collections import OrderedDict
from sklearn.model_selection import train_test_split
IMG_EXTENSIONS = ".npy" IMG_EXTENSIONS = ".npy"
class Threshold_noise: class Threshold_noise:
...@@ -54,11 +56,13 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): ...@@ -54,11 +56,13 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
print('Default val transform') print('Default val transform')
train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
generator1 = torch.Generator().manual_seed(42)
indices = torch.randperm(len(train_dataset),generator=generator1) #Same seed to avoid overlap while having different transforms
val_size = len(train_dataset) // 5 seed = random.randint(0,1000)
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size]) train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:]) stratify=True)
_, val_dataset = train_test_split(val_dataset, test_size=None, train_size=None, random_state=seed, shuffle=True,
stratify=True)
data_loader_train = data.DataLoader( data_loader_train = data.DataLoader(
dataset=train_dataset, dataset=train_dataset,
...@@ -198,6 +202,7 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0): ...@@ -198,6 +202,7 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
Log_normalisation(), Log_normalisation(),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
print('Default val transform') print('Default val transform')
train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform) train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform) val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
generator1 = torch.Generator().manual_seed(42) generator1 = torch.Generator().manual_seed(42)
......
...@@ -257,6 +257,7 @@ def load_model(model, path): ...@@ -257,6 +257,7 @@ def load_model(model, path):
if __name__ == '__main__': if __name__ == '__main__':
args = load_args() args = load_args()
if args.model_type=='duo': if args.model_type=='duo':
run_duo(args) # run_duo(args)
data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
else : else :
run(args) run(args)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment