"""Data-loading utilities for paired AER/ANA .npy images with per-class reference images."""

import random
import os
import os.path
from collections import OrderedDict
from pathlib import Path
from typing import Callable, cast, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image

IMG_EXTENSIONS = ".npy"


class Threshold_noise:
    """Zero out intensities at or below the given threshold."""

    def __init__(self, threshold=100.):
        self.threshold = threshold

    def __call__(self, x):
        return torch.where(x <= self.threshold, 0., x)


class Log_normalisation:
    """Log-normalise intensities, scaled by the log of the maximum intensity."""

    def __init__(self, eps=1e-5):
        self.epsilon = eps

    def __call__(self, x):
        return torch.log(x + 1 + self.epsilon) / torch.log(torch.max(x) + 1 + self.epsilon)


class Random_shift_rt():
    """Placeholder transform (not implemented)."""
    pass


def default_loader(path):
    """Load an image with PIL and convert it to RGB."""
    return Image.open(path).convert('RGB')


def npy_loader(path):
    """Load a .npy array as a tensor and add a leading channel dimension."""
    sample = torch.from_numpy(np.load(path))
    sample = sample.unsqueeze(0)
    return sample


def remove_aer_ana(fnames):
    """Strip the '_AER.npy' / '_ANA.npy' suffix from filenames and deduplicate,
    preserving order. Assumes base names contain no underscore."""
    fnames = map(lambda x: x.split('_')[0], fnames)
    return list(OrderedDict.fromkeys(fnames))


def make_dataset_custom(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, str, int]]:
    """Generate a list of samples of the form (path_aer, path_ana, class_index).

    Only base names for which both '<name>_AER.npy' and '<name>_ANA.npy' exist under
    '<directory>/<class_name>/' are kept. See :class:`DatasetFolder` for details.

    Note: the class_to_idx parameter is optional and defaults to the logic of the
    ``find_classes`` function.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            # Collapse the paired '<name>_AER.npy' / '<name>_ANA.npy' files into base names.
            fnames_base = remove_aer_ana(fnames)
            for fname in sorted(fnames_base):
                fname_ana = fname + '_ANA.npy'
                fname_aer = fname + '_AER.npy'
                path_ana = os.path.join(root, fname_ana)
                path_aer = os.path.join(root, fname_aer)
                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
                    item = path_aer, path_ana, class_index
                    instances.append(item)
                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class ImageFolderDuo(data.Dataset):
    """Training dataset: yields (AER image, ANA image, reference image, binary label)."""

    def __init__(self, root, transform=None, target_transform=None, flist_reader=make_dataset_custom,
                 loader=npy_loader, ref_dir=None, positive_prop=None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir
        self.positive_prop = positive_prop

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        if self.positive_prop is not None:
            # Draw a positive pair (reference class == true class) roughly positive_prop % of the time.
            i = random.randint(0, 100)
            if i < self.positive_prop:
                label_ref = target
            else:
                # Random reference class; it can still coincide with the true class (border effect).
                label_ref = np.random.randint(0, len(self.classes))
        else:
            label_ref = np.random.randint(0, len(self.classes))
        class_ref = self.classes[label_ref]
        path_ref = os.path.join(self.ref_dir, class_ref + '.npy')
        img_ref = self.loader(path_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
            img_ref = self.transform(img_ref)
        # Binary target: 0 if the reference class matches the sample's class, 1 otherwise.
        target = 0 if target == label_ref else 1
        return imgAER, imgANA, img_ref, target

    def __len__(self):
        return len(self.imlist)


def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0,
                  ref_dir=None, positive_prop=None):
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        Threshold_noise(noise_threshold),
        Log_normalisation(),
        transforms.Normalize(0.5, 0.5),
    ])
    print('Default train transform')
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        Threshold_noise(noise_threshold),
        Log_normalisation(),
        transforms.Normalize(0.5, 0.5),
    ])
    print('Default val transform')
    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform,
                                   ref_dir=ref_dir, positive_prop=positive_prop)
    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir)
    data_loader_train = data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    # batch_size=1 because each ImageFolderDuo_Batched item is already a class-sized batch.
    data_loader_test = data.DataLoader(
        dataset=val_dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    return data_loader_train, data_loader_test


class ImageFolderDuo_Batched(data.Dataset):
    """Validation dataset: each sample is paired with the reference image of every class."""

    def __init__(self, root, transform=None, target_transform=None, flist_reader=make_dataset_custom,
                 loader=npy_loader, ref_dir=None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        img_refs = []
        label_refs = []
        for ind_ref in range(len(self.classes)):
            class_ref = self.classes[ind_ref]
            # 0 when the reference class is the sample's true class, 1 otherwise.
            target_ref = 0 if target == ind_ref else 1
            path_ref = os.path.join(self.ref_dir, class_ref + '.npy')
            img_ref = self.loader(path_ref)
            if self.transform is not None:
                img_ref = self.transform(img_ref)
            img_refs.append(img_ref)
            label_refs.append(target_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
        # Stack the per-class references and repeat the sample images to match.
        batched_im_ref = torch.concat(img_refs, dim=0)
        batched_label = torch.tensor(label_refs)
        batched_imgAER = imgAER.repeat(len(self.classes), 1, 1)
        batched_imgANA = imgANA.repeat(len(self.classes), 1, 1)
        return batched_imgAER, batched_imgANA, batched_im_ref, batched_label

    def __len__(self):
        return len(self.imlist)
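
# Minimal usage sketch, not part of the original module: the directory names below
# ('data/train', 'data/test', 'data/refs') and the hyper-parameters are placeholder
# assumptions chosen only to illustrate how load_data_duo is called and what it yields.
if __name__ == "__main__":
    train_loader, test_loader = load_data_duo(
        base_dir_train="data/train",
        base_dir_test="data/test",
        batch_size=8,
        noise_threshold=100,
        ref_dir="data/refs",
        positive_prop=50,
    )
    # Training batches: (B, 1, 224, 224) tensors for the AER, ANA and reference images,
    # plus a (B,) tensor of 0/1 labels (0 = reference class matches the sample's class).
    imgAER, imgANA, img_ref, target = next(iter(train_loader))
    print(imgAER.shape, imgANA.shape, img_ref.shape, target.shape)
    # Validation batches pair each sample with every class reference:
    # (1, n_classes, 224, 224) image tensors and a (1, n_classes) label tensor.
    bAER, bANA, b_ref, b_label = next(iter(test_loader))
    print(bAER.shape, bANA.shape, b_ref.shape, b_label.shape)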