# dataset_ref.py
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import os
import os.path
from typing import Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict
IMG_EXTENSIONS = ".npy"
class Threshold_noise:
    """Zero out intensities at or below the given threshold."""
    def __init__(self, threshold=100.):
        self.threshold = threshold

    def __call__(self, x):
        return torch.where((x <= self.threshold), 0., x)
class Log_normalisation:
    """Log normalisation of intensities"""
    def __init__(self, eps=1e-5):
        self.epsilon = eps

    def __call__(self, x):
        return torch.log(x + 1 + self.epsilon) / torch.log(torch.max(x) + 1 + self.epsilon)
class Random_shift_rt:
    """Placeholder transform (not implemented)."""
    pass
def default_loader(path):
    return Image.open(path).convert('RGB')
def npy_loader(path):
    # Load a .npy array as a tensor and add a channel dimension: (H, W) -> (1, H, W).
    sample = torch.from_numpy(np.load(path))
    sample = sample.unsqueeze(0)
    return sample
def remove_aer_ana(l):
    """Keep the part of each file name before the first underscore and deduplicate while preserving order."""
    l = map(lambda x: x.split('_')[0], l)
    return list(OrderedDict.fromkeys(l))
def make_dataset_custom(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, str, int]]:
    """Generates a list of samples of the form (path_to_aer, path_to_ana, class_index).

    See :class:`DatasetFolder` for details.
    Note: The class_to_idx parameter is optional here and defaults to the logic of the
    ``find_classes`` function.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            # Collapse the paired '<base>_AER.npy' / '<base>_ANA.npy' files into unique base names,
            # then keep only the bases for which both files exist and are valid.
            fnames_base = remove_aer_ana(fnames)
            for fname in sorted(fnames_base):
                fname_ana = fname + '_ANA.npy'
                fname_aer = fname + '_AER.npy'
                path_ana = os.path.join(root, fname_ana)
                path_aer = os.path.join(root, fname_aer)
                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
                    item = path_aer, path_ana, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class ImageFolderDuo(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir=None, positive_prop=None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir
        self.positive_prop = positive_prop

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        # Pick the class of the reference image: with probability positive_prop (in %) use the
        # true class, otherwise draw a random class.
        if self.positive_prop is not None:
            i = random.randint(0, 100)
            if i < self.positive_prop:
                label_ref = target
            else:
                label_ref = np.random.randint(0, len(self.classes) - 1)  # can be positive too (border effect)
        else:
            label_ref = np.random.randint(0, len(self.classes) - 1)
        class_ref = self.classes[label_ref]
        path_ref = self.ref_dir + '/' + class_ref + '.npy'
        img_ref = self.loader(path_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
            img_ref = self.transform(img_ref)
        # Binary target: 0 if the reference comes from the same class as the sample, 1 otherwise.
        target = 0 if target == label_ref else 1
        return imgAER, imgANA, img_ref, target

    def __len__(self):
        return len(self.imlist)
def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir=None, positive_prop=None):
    train_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default train transform')
    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default val transform')
    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir=ref_dir, positive_prop=positive_prop)
    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir)
    data_loader_train = data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    # batch_size=1 because ImageFolderDuo_Batched already returns one comparison per class for each item.
    data_loader_test = data.DataLoader(
        dataset=val_dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    return data_loader_train, data_loader_test
class ImageFolderDuo_Batched(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir=None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        # Pair the sample with the reference image of every class, so one item yields a full
        # batch of (sample, reference) comparisons for evaluation.
        img_refs = []
        label_refs = []
        for ind_ref in range(len(self.classes)):
            class_ref = self.classes[ind_ref]
            target_ref = 0 if target == ind_ref else 1
            path_ref = self.ref_dir + '/' + class_ref + '.npy'
            img_ref = self.loader(path_ref)
            if self.transform is not None:
                img_ref = self.transform(img_ref)
            img_refs.append(img_ref)
            label_refs.append(target_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
        batched_im_ref = torch.concat(img_refs, dim=0)
        batched_label = torch.tensor(label_refs)
        batched_imgAER = imgAER.repeat(len(self.classes), 1, 1)
        batched_imgANA = imgANA.repeat(len(self.classes), 1, 1)
        return batched_imgAER, batched_imgANA, batched_im_ref, batched_label

    def __len__(self):
        return len(self.imlist)
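

# Minimal usage sketch (illustrative only). The paths and parameter values below are
# hypothetical placeholders, not taken from this repository; it assumes base_dir_train /
# base_dir_test each contain one subfolder per class holding paired '<base>_AER.npy' /
# '<base>_ANA.npy' arrays, and that ref_dir holds one '<class>.npy' reference per class,
# as expected by make_dataset_custom and ImageFolderDuo.
if __name__ == "__main__":
    train_loader, test_loader = load_data_duo(
        base_dir_train="data/train",    # hypothetical path
        base_dir_test="data/test",      # hypothetical path
        batch_size=8,
        noise_threshold=100,
        ref_dir="data/references",      # hypothetical path
        positive_prop=50,               # ~50% of training pairs use the true class as reference
    )
    imgAER, imgANA, img_ref, target = next(iter(train_loader))
    print(imgAER.shape, imgANA.shape, img_ref.shape, target.shape)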