From 3697972d1a1149b16b72d6f95ad7d5244484dd20 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 12 Mar 2025 14:39:21 +0100
Subject: [PATCH] dataset duo

---
 dataset/dataset.py                | 162 +++++++++++++++++++++++++++++-
 image_processing/build_dataset.py |  29 +++++-
 main.py                           |   4 +-
 3 files changed, 186 insertions(+), 9 deletions(-)

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 446d281..e5a1a7b 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,8 +1,15 @@
 import torch
 import torchvision
-from torch.utils.data import DataLoader, random_split
 import torchvision.transforms as transforms
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from collections import OrderedDict
 
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
 
 class Threshold_noise:
     """Remove intensities under given threshold"""
@@ -52,7 +59,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
     val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
 
-    data_loader_train = DataLoader(
+    data_loader_train = data.DataLoader(
         dataset=train_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -61,7 +68,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
 
-    data_loader_test = DataLoader(
+    data_loader_test = data.DataLoader(
         dataset=val_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -70,4 +77,151 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
 
-    return data_loader_train, data_loader_test
\ No newline at end of file
+    return data_loader_train, data_loader_test
+
+
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+def remove_aer_ana(l):
+    # strip the trailing '_AER.png' / '_ANA.png' part to get the base sample name
+    l = [x.rsplit('_', 1)[0] for x in l]
+    return list(OrderedDict.fromkeys(l))
+
+def make_dataset_custom(
+    directory: Union[str, Path],
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+) -> List[Tuple[str, str, int]]:
+    """Generates a list of samples of the form (path_to_aer_sample, path_to_ana_sample, class_index).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            fnames_base = remove_aer_ana(fnames)
+            for fname in sorted(fnames_base):
+                fname_ana = fname + '_ANA.png'
+                fname_aer = fname + '_AER.png'
+                path_ana = os.path.join(root, fname_ana)
+                path_aer = os.path.join(root, fname_aer)
+                if is_valid_file(path_ana) and is_valid_file(path_aer):
+                    item = path_aer, path_ana, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes and not allow_empty:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class ImageFolderDuo(data.Dataset):
+    def __init__(self, root, transform=None, target_transform=None,
+                 flist_reader=make_dataset_custom, loader=default_loader):
+        self.root = root
+        self.imlist = flist_reader(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.loader = loader
+        self.classes, self.class_to_idx = torchvision.datasets.folder.find_classes(root)
+
+    def __getitem__(self, index):
+        impathAER, impathANA, target = self.imlist[index]
+        imgAER = self.loader(impathAER)  # paths from make_dataset_custom already include root
+        imgANA = self.loader(impathANA)
+        if self.transform is not None:
+            imgAER = self.transform(imgAER)
+            imgANA = self.transform(imgANA)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return imgAER, imgANA, target
+
+    def __len__(self):
+        return len(self.imlist)
+
+def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
+    train_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
+
+    val_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
+    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
+    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
+    generator1 = torch.Generator().manual_seed(42)
+    indices = torch.randperm(len(train_dataset), generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    data_loader_train = data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = data.DataLoader(
+        dataset=val_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
+
+
diff --git a/image_processing/build_dataset.py b/image_processing/build_dataset.py
index 1123acd..49a5153 100644
--- a/image_processing/build_dataset.py
+++ b/image_processing/build_dataset.py
@@ -13,7 +13,7 @@ find . -name '*.mzML' -exec cp -prv '{}' '/home/leo/PycharmProjects/pseudo_image
 copy des mzml depuis lecteur
 """
 
-def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx'):
+def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx',suffix='-d200'):
     """
     Extract and organise labels from raw excel file
     :param path: excel path
@@ -38,7 +38,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
             l = split_before_number(s)
             species = l[0]
             nb = l[1]
-            return '{}-{}-{}-d200.mzML'.format(species,nb,analyse)
+            return '{}-{}-{}{}.mzML'.format(species,nb,analyse,suffix)
 
     df['path_ana'] = df['sample_name'].map(lambda x: create_fname(x,analyse='ANA'))
     df['path_aer'] = df['sample_name'].map(lambda x: create_fname(x, analyse='AER'))
@@ -51,7 +51,30 @@ def create_dataset():
     Create images from raw .mzML files and sort it in their corresponding class directory
     :return: None
     """
-    label = create_antibio_dataset()
+    label = create_antibio_dataset(suffix='-d200')
+    for path in glob.glob("../data/raw_data/**.mzML"):
+        print(path)
+        species = None
+        if path.split("/")[-1] in label['path_ana'].values:
+            species = label[label['path_ana'] == path.split("/")[-1]]['species'].values[0]
+            name = label[label['path_ana'] == path.split("/")[-1]]['sample_name'].values[0]
+            analyse = 'ANA'
+        elif path.split("/")[-1] in label['path_aer'].values:
+            species = label[label['path_aer'] == path.split("/")[-1]]['species'].values[0]
+            name = label[label['path_aer'] == path.split("/")[-1]]['sample_name'].values[0]
+            analyse = 'AER'
+        if species is not None:
+            directory_path_png = '../data/processed_data/png_image/{}'.format(species)
+            directory_path_npy = '../data/processed_data/npy_image/{}'.format(species)
+            if not os.path.isdir(directory_path_png):
+                os.makedirs(directory_path_png)
+            if not os.path.isdir(directory_path_npy):
+                os.makedirs(directory_path_npy)
+            mat = build_image_ms1(path, 1)
+            mpimg.imsave(directory_path_png + "/" + name + '_' + analyse + '.png', mat)
+            np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
+
+    label = create_antibio_dataset(suffix='_100vW_100SPD')
     for path in glob.glob("../data/raw_data/**.mzML"):
         print(path)
         species = None
diff --git a/main.py b/main.py
index af15fa9..13946a3 100644
--- a/main.py
+++ b/main.py
@@ -89,10 +89,10 @@ def run(args):
     plt.plot(train_acc)
     plt.plot(train_acc)
     plt.show()
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
     load_model(model, args.save_path)
 
-    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
+    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
 
 
 def make_prediction(model, data, f_name):
--
GitLab
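
Note (not part of the commit): a minimal usage sketch of the new duo data pipeline, assuming the patch above is applied. `load_data_duo` comes from the patched dataset/dataset.py; the base_dir value (the class-sorted png_image tree written by image_processing/build_dataset.py), the batch size, and the loop body are illustrative assumptions only.

    # usage_sketch.py -- illustrative only, not part of the patch
    from dataset.dataset import load_data_duo

    # assumed layout: one sub-folder per species, each holding <sample>_AER.png / <sample>_ANA.png pairs
    data_train, data_val = load_data_duo(
        base_dir='../data/processed_data/png_image',  # assumption: output folder of build_dataset.py
        batch_size=8,
        shuffle=True,
        noise_threshold=0,
    )

    # each batch yields the AER image, the ANA image and the class index tensor
    for img_aer, img_ana, target in data_train:
        print(img_aer.shape, img_ana.shape, target.shape)  # e.g. [8, 1, 224, 224], [8, 1, 224, 224], [8]
        break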