# dataset_ref.py
import random

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import os
import os.path
from typing import Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict

IMG_EXTENSIONS = ".npy"

class Threshold_noise:
    """Zero out intensities at or below a given threshold."""

    def __init__(self, threshold=100.):
        self.threshold = threshold

    def __call__(self, x):
        return torch.where(x <= self.threshold, 0., x)

class Log_normalisation:
    """Log normalisation of intensities"""

    def __init__(self, eps=1e-5):
        self.epsilon = eps

    def __call__(self, x):
        return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
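
# Illustrative check of Log_normalisation (example values, not from the original code):
# the maximum intensity maps to exactly 1 and zeros map close to 0.
#   >>> Log_normalisation()(torch.tensor([0., 10., 100.]))
#   -> approximately tensor([0.0000, 0.5196, 1.0000])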

class Random_shift_rt:
    """Placeholder for a random rt-shift augmentation (not implemented yet)."""


def default_loader(path):
    # Generic PIL image loader; the datasets below default to npy_loader instead.
    return Image.open(path).convert('RGB')

def npy_loader(path):
    # Load an array saved as .npy and add a leading channel dimension, e.g. (H, W) -> (1, H, W).
    sample = torch.from_numpy(np.load(path))
    sample = sample.unsqueeze(0)
    return sample

def remove_aer_ana(l):
    # Strip the trailing '_AER.npy' / '_ANA.npy' suffix and deduplicate while preserving
    # order; rsplit keeps any underscores inside the base name intact.
    l = map(lambda x: x.rsplit('_', 1)[0], l)
    return list(OrderedDict.fromkeys(l))
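
# Example of the expected file-name pairing (illustrative names, not taken from the
# original data):
#   remove_aer_ana(['s1_AER.npy', 's1_ANA.npy', 's2_AER.npy', 's2_ANA.npy'])
#   -> ['s1', 's2']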

def make_dataset_custom(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            # Each sample is stored as a pair of files: '<base>_AER.npy' and '<base>_ANA.npy'.
            fnames_base = remove_aer_ana(fnames)
            for fname in sorted(fnames_base):
                fname_ana = fname + '_ANA.npy'
                fname_aer = fname + '_AER.npy'
                path_ana = os.path.join(root, fname_ana)
                path_aer = os.path.join(root, fname_aer)
                # Keep the sample only if both halves of the pair exist and are valid.
                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
                    item = path_aer, path_ana, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
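
# Expected on-disk layout for make_dataset_custom (class and sample names below are
# hypothetical): each class directory holds '<base>_AER.npy' / '<base>_ANA.npy' pairs,
# and the returned samples are (AER path, ANA path, class index) triples, e.g.
#
#   root/classA/sample1_AER.npy
#   root/classA/sample1_ANA.npy
#   make_dataset_custom("root") -> [("root/classA/sample1_AER.npy",
#                                    "root/classA/sample1_ANA.npy", 0)]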


class ImageFolderDuo(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir
        self.positive_prop = positive_prop

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        if self.positive_prop is not None:
            # positive_prop is interpreted as a percentage: with that probability the
            # reference spectrum belongs to the same class as the sample.
            if random.randint(0, 99) < self.positive_prop:
                label_ref = target
            else:
                # The random reference can still coincide with the target class (positive pair).
                label_ref = np.random.randint(0, len(self.classes))
        else:
            label_ref = np.random.randint(0, len(self.classes))
        class_ref = self.classes[label_ref]
        path_ref = os.path.join(self.ref_dir, class_ref + '.npy')
        img_ref = self.loader(path_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
            img_ref = self.transform(img_ref)
        # Binary target: 0 if the reference matches the sample's class, 1 otherwise.
        target = 0 if target == label_ref else 1
        return imgAER, imgANA, img_ref, target

    def __len__(self):
        return len(self.imlist)

def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None):

    train_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default train transform')

    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default val transform')

    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir=ref_dir, positive_prop=positive_prop)
    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir)

    data_loader_train = data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )

    # batch_size is fixed to 1 because ImageFolderDuo_Batched already stacks one
    # (sample, reference) pair per class into each item.
    data_loader_test = data.DataLoader(
        dataset=val_dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )

    return data_loader_train, data_loader_test


class ImageFolderDuo_Batched(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes = torchvision.datasets.folder.find_classes(root)[0]
        self.ref_dir = ref_dir

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        # Build one (sample, reference) pair per class so the whole class set is
        # evaluated in a single item.
        img_refs = []
        label_refs = []
        for ind_ref in range(len(self.classes)):
            class_ref = self.classes[ind_ref]
            target_ref = 0 if target == ind_ref else 1
            path_ref = os.path.join(self.ref_dir, class_ref + '.npy')
            img_ref = self.loader(path_ref)
            if self.transform is not None:
                img_ref = self.transform(img_ref)
            img_refs.append(img_ref)
            label_refs.append(target_ref)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)

        # Each reference is (1, H, W); concatenating along dim 0 gives (num_classes, H, W).
        batched_im_ref = torch.cat(img_refs, dim=0)
        batched_label = torch.tensor(label_refs)
        # Repeat the sample so it is paired with every class reference.
        batched_imgAER = imgAER.repeat(len(self.classes), 1, 1)
        batched_imgANA = imgANA.repeat(len(self.classes), 1, 1)

        return batched_imgAER, batched_imgANA, batched_im_ref, batched_label

    def __len__(self):
        return len(self.imlist)
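

# Minimal usage sketch (paths and parameter values below are hypothetical), assuming the
# directory layout described above plus a reference directory holding one '<class>.npy'
# file per class.
if __name__ == "__main__":
    train_loader, test_loader = load_data_duo(
        base_dir_train="data/train",
        base_dir_test="data/test",
        batch_size=8,
        noise_threshold=100,
        ref_dir="data/references",
        positive_prop=50,  # ~50% of training pairs use the reference of the true class
    )
    imgAER, imgANA, img_ref, target = next(iter(train_loader))
    print(imgAER.shape, imgANA.shape, img_ref.shape, target)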