# dataset_ref.py -- duo spectral-image dataset utilities
# (authored by Schneider Leo, revision 95a866fe)
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import os
import os.path
from typing import Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict
from torch.utils.data import WeightedRandomSampler
IMG_EXTENSIONS = ".npy"
class Threshold_noise:
    """Zero out every intensity at or below a given threshold."""
    def __init__(self, threshold=100.):
        self.threshold = threshold
    def __call__(self, x):
        # Keep only values strictly above the threshold; everything else -> 0.
        keep = x > self.threshold
        return torch.where(keep, x, 0.)
class Log_normalisation:
    """Log-scale intensities into [0, 1]: log(x+1+eps) / log(max(x)+1+eps)."""
    def __init__(self, eps=1e-5):
        self.epsilon = eps
    def __call__(self, x):
        # The global max maps to exactly 1; eps keeps log well-defined at 0.
        numerator = torch.log(x + 1 + self.epsilon)
        denominator = torch.log(torch.max(x) + 1 + self.epsilon)
        return numerator / denominator
class Random_shift_rt:
    # NOTE(review): unimplemented stub -- __call__ falls through and returns
    # None, so placing this transform in a pipeline would silently drop the
    # tensor. Presumably intended to circularly shift the retention-time axis
    # by up to max_cycle positions; confirm intent before wiring it in.
    def __init__(self, max_cycle=5):
        # Maximum shift amplitude; currently unused by the stub below.
        self.max_cycle = max_cycle
    def __call__(self, x):
        pass
class Intensity_shift:
    """Randomly rescale each element's intensity by an independent factor
    drawn uniformly from [min_ratio, max_ratio).

    Fixes vs. original: ``x.shape`` is a property, not a method (the original
    ``x.shape()`` raised TypeError), and the drawn factors now honour
    ``min_ratio``/``max_ratio``, which were previously stored but ignored.
    """
    def __init__(self, min_ratio=0.1, max_ratio=10):
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
    def __call__(self, x):
        # Uniform per-element factors in [min_ratio, max_ratio).
        span = self.max_ratio - self.min_ratio
        intensity_ratio_map = self.min_ratio + span * torch.rand(x.shape)
        return torch.mul(x, intensity_ratio_map)
def default_loader(path):
    """Open an image file with PIL and force it into RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')
def npy_loader(path):
    """Load a .npy file as a tensor with a leading channel dimension added."""
    array = np.load(path)
    tensor = torch.from_numpy(array)
    return tensor[None, ...]
def remove_aer_ana(l):
    """Strip the _AER/_ANA suffixes from filenames and de-duplicate the
    resulting base names, preserving first-seen order."""
    bases = []
    for name in l:
        base = name.split('_')[0]
        if base not in bases:
            bases.append(base)
    return bases
def make_dataset_custom(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, str, int]]:
    """Generate a list of paired samples (path_aer, path_ana, class_index).

    Files are expected to come in pairs named ``<base>_AER.npy`` /
    ``<base>_ANA.npy`` inside one sub-directory per class; only bases for
    which BOTH files exist and pass the validity check are kept.

    Args:
        directory: Root dataset directory (one sub-directory per class).
        class_to_idx: Optional mapping from class name to index; by default
            built with ``torchvision.datasets.folder.find_classes``.
        extensions: Allowed extension(s); mutually exclusive with
            ``is_valid_file``.
        is_valid_file: Optional predicate deciding which files are kept.
        allow_empty: When False, raise if some class yields no valid pair.

    Raises:
        ValueError: If an explicit ``class_to_idx`` is empty, or if both /
            neither of ``extensions`` and ``is_valid_file`` are given.
        FileNotFoundError: If a class has no valid pair and ``allow_empty``
            is False.
    """
    directory = os.path.expanduser(directory)
    if class_to_idx is None:
        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
    elif not class_to_idx:
        # An explicitly-passed empty mapping can never collect samples.
        # (Message fixed to match the actual parameter name.)
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            # Collapse the _AER/_ANA pairs down to their shared base names.
            fnames_base = remove_aer_ana(fnames)
            for fname in sorted(fnames_base):
                path_ana = os.path.join(root, fname + '_ANA.npy')
                path_aer = os.path.join(root, fname + '_AER.npy')
                # Keep a pair only when both halves exist and are valid.
                if is_valid_file(path_ana) and is_valid_file(path_aer) \
                        and os.path.isfile(path_ana) and os.path.isfile(path_aer):
                    instances.append((path_aer, path_ana, class_index))
                    available_classes.add(target_class)
    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)
    return instances
class ImageFolderDuo(data.Dataset):
    """Dataset yielding (imgAER, imgANA, img_ref, contrastive_target) tuples.

    Each sample pairs one _AER/_ANA image couple with a per-class reference
    image loaded from ``ref_dir``; ``contrastive_target`` is 0 when the
    reference belongs to the sample's class and 1 otherwise.
    """
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir=None,
                 positive_prop=None, ref_transform=None, base_dir=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.ref_transform = ref_transform
        self.loader = loader
        # Resolve paths against base_dir FIRST so find_classes and the file
        # list are built from the same directory (the original called
        # find_classes(root) even when base_dir was set, which diverged from
        # the joined path used for imlist).
        if base_dir is not None:
            data_root = os.path.join(base_dir, root)
            self.ref_dir = os.path.join(base_dir, ref_dir)
        else:
            data_root = root
            self.ref_dir = ref_dir
        self.classes = torchvision.datasets.folder.find_classes(data_root)[0]
        self.imlist = flist_reader(data_root)
        # Percentage (0-100) of draws that should use a same-class reference;
        # None means a uniformly random reference class.
        self.positive_prop = positive_prop

    def __getitem__(self, index):
        """Return one (imgAER, imgANA, img_ref, contrastive_target) tuple."""
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        if self.positive_prop is not None:
            # Positive (same-class) reference with probability ~positive_prop
            # percent, otherwise a uniformly random OTHER class.
            if random.randint(0, 100) < self.positive_prop:
                label_ref = target
            else:
                candidates = np.arange(0, len(self.classes), dtype=int).tolist()
                candidates.remove(target)
                label_ref = np.random.choice(candidates)
        else:
            # np.random.randint excludes the upper bound: the original
            # len(self.classes) - 1 could never select the last class.
            label_ref = np.random.randint(0, len(self.classes))
        class_ref = self.classes[label_ref]
        path_ref = self.ref_dir + '/' + class_ref + '.npy'
        img_ref = self.loader(path_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
        # Guard on ref_transform itself (the original applied it under the
        # transform guard and crashed when ref_transform was None).
        if self.ref_transform is not None:
            img_ref = self.ref_transform(img_ref)
        contrastive_target = 0 if target == label_ref else 1
        return imgAER, imgANA, img_ref, contrastive_target

    def __len__(self):
        return len(self.imlist)
def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,
                  noise_threshold=0, ref_dir=None, positive_prop=None, sampler=None, base_dir=None):
    """Build the train / val / test DataLoaders for the duo datasets.

    Args:
        base_dir_train, base_dir_val, base_dir_test: dataset roots (test may
            be None, in which case the test loader is None).
        batch_size: batch size for the training loader (val/test use 1, since
            the batched datasets already emit one class-sized batch per item).
        shuffle: shuffling flag for the loaders (ignored for the training
            loader when the balanced sampler is active).
        noise_threshold: intensity threshold applied by train/val transforms.
        ref_dir: directory holding one reference .npy per class.
        positive_prop: percentage of positive reference pairs (train only).
        sampler: 'balanced' to weight-sample training classes uniformly.
        base_dir: optional prefix joined in front of every path.

    Returns:
        (data_loader_train, data_loader_val, data_loader_test).
    """
    def _make_transform(threshold):
        # Shared pipeline: resize -> denoise -> log-normalise -> standardise.
        return transforms.Compose(
            [transforms.Resize((224, 224)),
             Threshold_noise(threshold),
             Log_normalisation(),
             transforms.Normalize(0.5, 0.5)])

    train_transform = _make_transform(noise_threshold)
    print('Default train transform')
    val_transform = _make_transform(noise_threshold)
    print('Default val transform')
    ref_transform = _make_transform(0)  # references are never thresholded
    print('Default ref transform')  # message fixed: was mislabelled 'val'
    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir=ref_dir,
                                   positive_prop=positive_prop, ref_transform=ref_transform, base_dir=base_dir)
    val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir=ref_dir,
                                         ref_transform=ref_transform, base_dir=base_dir)
    test_dataset = None
    if base_dir_test is not None:
        test_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir,
                                              ref_transform=ref_transform, base_dir=base_dir)
    if sampler == 'balanced':
        # Weight each sample by the inverse frequency of its class so every
        # class is drawn with equal probability.
        # NOTE(review): weight[t] assumes class indices are contiguous 0..K-1;
        # confirm against make_dataset_custom's class_to_idx.
        y_train_label = np.array([i for (_, _, i) in train_dataset.imlist])
        class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)])
        weight = 1. / class_sample_count
        samples_weight = torch.from_numpy(np.array([weight[t] for t in y_train_label]))
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
        data_loader_train = data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
        )
    else:
        data_loader_train = data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
        )
    data_loader_val = data.DataLoader(
        dataset=val_dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    data_loader_test = None
    if test_dataset is not None:
        data_loader_test = data.DataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=shuffle,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
        )
    return data_loader_train, data_loader_val, data_loader_test
class ImageFolderDuo_Batched(data.Dataset):
    """Like ImageFolderDuo, but each item pairs the sample with EVERY class
    reference, returning class-count-sized batches of images plus 0/1 labels
    (0 = reference matches the sample's class). Intended for evaluation with
    DataLoader batch_size=1."""
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir=None,
                 ref_transform=None, base_dir=None):
        self.root = root
        # Resolve against base_dir FIRST so find_classes and the file list use
        # the same directory (the original called find_classes(root) even when
        # base_dir was set).
        if base_dir is not None:
            data_root = os.path.join(base_dir, root)
            self.ref_dir = os.path.join(base_dir, ref_dir)
        else:
            data_root = root
            self.ref_dir = ref_dir
        self.imlist = flist_reader(data_root)
        self.transform = transform
        self.ref_transform = ref_transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(data_root)[0]

    def __getitem__(self, index):
        """Return (batched_imgAER, batched_imgANA, batched_im_ref, batched_label),
        each with a leading dimension of len(self.classes)."""
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        img_refs = []
        label_refs = []
        for ind_ref, class_ref in enumerate(self.classes):
            img_ref = self.loader(self.ref_dir + '/' + class_ref + '.npy')
            # Guard on ref_transform itself (the original tested transform and
            # would crash when ref_transform was None).
            if self.ref_transform is not None:
                img_ref = self.ref_transform(img_ref)
            img_refs.append(img_ref)
            label_refs.append(0 if target == ind_ref else 1)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
        # Stack the references along dim 0 and tile the sample pair so every
        # reference is compared against the same AER/ANA images.
        batched_im_ref = torch.concat(img_refs, dim=0)
        batched_label = torch.tensor(label_refs)
        batched_imgAER = imgAER.repeat(len(self.classes), 1, 1)
        batched_imgANA = imgANA.repeat(len(self.classes), 1, 1)
        return batched_imgAER, batched_imgANA, batched_im_ref, batched_label

    def __len__(self):
        return len(self.imlist)