From 3697972d1a1149b16b72d6f95ad7d5244484dd20 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 12 Mar 2025 14:39:21 +0100
Subject: [PATCH] dataset duo

---
 dataset/dataset.py                | 162 +++++++++++++++++++++++++++++-
 image_processing/build_dataset.py |  29 +++++-
 main.py                           |   4 +-
 3 files changed, 186 insertions(+), 9 deletions(-)

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 446d281..e5a1a7b 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,8 +1,15 @@
 import torch
 import torchvision
-from torch.utils.data import DataLoader, random_split
 import torchvision.transforms as transforms
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from collections import OrderedDict
 
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
 
 class Threshold_noise:
     """Remove intensities under given threshold"""
@@ -52,7 +59,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
     val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
 
-    data_loader_train = DataLoader(
+    data_loader_train = data.DataLoader(
         dataset=train_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -61,7 +68,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
 
-    data_loader_test = DataLoader(
+    data_loader_test = data.DataLoader(
         dataset=val_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -70,4 +77,151 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
 
-    return data_loader_train, data_loader_test
\ No newline at end of file
+    return data_loader_train, data_loader_test
+
+
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+def remove_aer_ana(l):
+    l = l.map(lambda x : x.split('_')[0])
+    return list(OrderedDict.fromkeys(l))
+
+def make_dataset_custom(
+    directory: Union[str, Path],
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+) -> List[Tuple[str, str, int]]:
+    """Generates a list of samples of a form (path_to_sample, class).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            fnames_base = remove_aer_ana(fnames)
+            for fname in sorted(fnames_base):
+                fname_ana = fname+'_ANA.png'
+                fname_aer = fname + '_AER.png'
+                path_ana = os.path.join(root, fname_ana)
+                path_aer = os.path.join(root, fname_aer)
+                if is_valid_file(path_ana) and is_valid_file(path_aer):
+                    item = path_aer, path_ana, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes and not allow_empty:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class ImageFolderDuo(data.Dataset):
+    def __init__(self, root, transform=None, target_transform=None,
+                 flist_reader=make_dataset_custom, loader=default_loader):
+        self.root = root
+        self.imlist = flist_reader(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.loader = loader
+        self.classes = torchvision.datasets.folder.find_classes(root)
+
+    def __getitem__(self, index):
+        impathAER, impathANA, target = self.imlist[index]
+        imgAER = self.loader(os.path.join(self.root, impathAER))
+        imgANA = self.loader(os.path.join(self.root, impathANA))
+        if self.transform is not None:
+            imgAER = self.transform(imgAER)
+            imgANA = self.transform(imgANA)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return imgAER, imgANA, target
+
+    def __len__(self):
+        return len(self.imlist)
+
+def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
+    train_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
+
+    val_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
+    train_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=train_transform)
+    val_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=val_transform)
+    generator1 = torch.Generator().manual_seed(42)
+    indices = torch.randperm(len(train_dataset), generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    data_loader_train = data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = data.DataLoader(
+        dataset=val_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
+
+
diff --git a/image_processing/build_dataset.py b/image_processing/build_dataset.py
index 1123acd..49a5153 100644
--- a/image_processing/build_dataset.py
+++ b/image_processing/build_dataset.py
@@ -13,7 +13,7 @@ find . -name '*.mzML' -exec cp -prv '{}' '/home/leo/PycharmProjects/pseudo_image
 copy des mzml depuis lecteur
 """
 
-def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx'):
+def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx',suffix='-d200'):
     """
     Extract and organise labels from raw excel file
     :param path: excel path
@@ -38,7 +38,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
         l = split_before_number(s)
         species = l[0]
         nb = l[1]
-        return '{}-{}-{}-d200.mzML'.format(species,nb,analyse)
+        return '{}-{}-{}{}.mzML'.format(species,nb,analyse,suffix)
 
     df['path_ana'] = df['sample_name'].map(lambda x: create_fname(x,analyse='ANA'))
     df['path_aer'] = df['sample_name'].map(lambda x: create_fname(x, analyse='AER'))
@@ -51,7 +51,30 @@ def create_dataset():
     Create images from raw .mzML files and sort it in their corresponding class directory
     :return: None
     """
-    label = create_antibio_dataset()
+    label = create_antibio_dataset(suffix='-d200')
+    for path in glob.glob("../data/raw_data/**.mzML"):
+        print(path)
+        species = None
+        if path.split("/")[-1] in label['path_ana'].values:
+            species = label[label['path_ana'] == path.split("/")[-1]]['species'].values[0]
+            name = label[label['path_ana'] == path.split("/")[-1]]['sample_name'].values[0]
+            analyse = 'ANA'
+        elif path.split("/")[-1] in label['path_aer'].values:
+            species = label[label['path_aer'] == path.split("/")[-1]]['species'].values[0]
+            name = label[label['path_aer'] == path.split("/")[-1]]['sample_name'].values[0]
+            analyse = 'AER'
+        if species is not None:
+            directory_path_png = '../data/processed_data/png_image/{}'.format(species)
+            directory_path_npy = '../data/processed_data/npy_image/{}'.format(species)
+            if not os.path.isdir(directory_path_png):
+                os.makedirs(directory_path_png)
+            if not os.path.isdir(directory_path_npy):
+                os.makedirs(directory_path_npy)
+            mat = build_image_ms1(path, 1)
+            mpimg.imsave(directory_path_png + "/" + name + '_' + analyse + '.png', mat)
+            np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
+
+    label = create_antibio_dataset(suffix='_100vW_100SPD')
     for path in glob.glob("../data/raw_data/**.mzML"):
         print(path)
         species = None
diff --git a/main.py b/main.py
index af15fa9..13946a3 100644
--- a/main.py
+++ b/main.py
@@ -89,10 +89,10 @@ def run(args):
     plt.plot(train_acc)
     plt.plot(train_acc)
     plt.show()
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
 
     load_model(model, args.save_path)
-    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
+    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
 
 
 def make_prediction(model, data, f_name):
-- 
GitLab