Commit 3697972d authored by Schneider Leo

dataset duo

parent 014eb92e
import torch
import torchvision
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
class Threshold_noise:
    """Remove intensities under a given threshold"""
@@ -52,7 +59,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
     val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
-    data_loader_train = DataLoader(
+    data_loader_train = data.DataLoader(
         dataset=train_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -61,7 +68,7 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
-    data_loader_test = DataLoader(
+    data_loader_test = data.DataLoader(
         dataset=val_dataset,
         batch_size=batch_size,
         shuffle=shuffle,
@@ -70,4 +77,151 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
         pin_memory=False,
     )
-    return data_loader_train, data_loader_test
\ No newline at end of file
+    return data_loader_train, data_loader_test
def default_loader(path):
    return Image.open(path).convert('RGB')
def remove_aer_ana(l):
    # Strip everything from the first '_' (i.e. the '_AER'/'_ANA' suffix) and
    # deduplicate while preserving order. Lists have no .map method, so a
    # comprehension is used here.
    l = [x.split('_')[0] for x in l]
    return list(OrderedDict.fromkeys(l))
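A quick sanity check of the helper (file names are hypothetical):

# ['s1_AER.png', 's1_ANA.png', 's2_AER.png'] -> ['s1', 's2']
print(remove_aer_ana(['s1_AER.png', 's1_ANA.png', 's2_AER.png']))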
def make_dataset_custom(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, str, int]]:
    """Generates a list of samples of the form (path_aer, path_ana, class_index).

    See :class:`DatasetFolder` for details.

    Note: the class_to_idx parameter is optional and defaults to the logic of
    the ``find_classes`` function.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            # Each sample is a '<name>_AER.png' / '<name>_ANA.png' pair; the
            # extension check below assumes both files of a pair exist on disk.
            fnames_base = remove_aer_ana(fnames)
            for fname in sorted(fnames_base):
                fname_ana = fname + '_ANA.png'
                fname_aer = fname + '_AER.png'
                path_ana = os.path.join(root, fname_ana)
                path_aer = os.path.join(root, fname_aer)
                if is_valid_file(path_ana) and is_valid_file(path_aer):
                    item = path_aer, path_ana, class_index
                    instances.append(item)
                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
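As a usage sketch, given the on-disk layout the loader expects (directory and file names hypothetical), make_dataset_custom pairs the two acquisitions of each sample:

# dataset_root/
#   ClassA/
#     strain1_AER.png
#     strain1_ANA.png
# ->
# [('dataset_root/ClassA/strain1_AER.png', 'dataset_root/ClassA/strain1_ANA.png', 0)]
samples = make_dataset_custom('dataset_root')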
class ImageFolderDuo(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 flist_reader=make_dataset_custom, loader=default_loader):
        self.root = root
        self.imlist = flist_reader(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        # find_classes returns a (classes, class_to_idx) tuple.
        self.classes, self.class_to_idx = torchvision.datasets.folder.find_classes(root)

    def __getitem__(self, index):
        impathAER, impathANA, target = self.imlist[index]
        # Paths returned by make_dataset_custom already include the root, so
        # they are loaded directly rather than joined with self.root again.
        imgAER = self.loader(impathAER)
        imgANA = self.loader(impathANA)
        if self.transform is not None:
            imgAER = self.transform(imgAER)
            imgANA = self.transform(imgANA)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return imgAER, imgANA, target

    def __len__(self):
        return len(self.imlist)
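A minimal sketch of consuming the paired dataset; the 'dataset_root' path is a hypothetical stand-in:

dataset = ImageFolderDuo(root='dataset_root', transform=transforms.ToTensor())
loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
for imgAER, imgANA, target in loader:
    # Each batch carries both acquisitions of the same samples plus one label each.
    print(imgAER.shape, imgANA.shape, target.shape)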
def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
    train_transform = transforms.Compose(
        [transforms.Grayscale(num_output_channels=1),
         transforms.ToTensor(),
         transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default train transform')
    val_transform = transforms.Compose(
        [transforms.Grayscale(num_output_channels=1),
         transforms.ToTensor(),
         transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default val transform')
    # ImageFolderDuo is the local class defined above, not a torchvision dataset.
    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
    generator1 = torch.Generator().manual_seed(42)
    indices = torch.randperm(len(train_dataset), generator=generator1)
    val_size = len(train_dataset) // 5
    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
    data_loader_train = data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    data_loader_test = data.DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
    )
    return data_loader_train, data_loader_test
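A hedged usage sketch of the duo pipeline; the data directory and DuoModel below are hypothetical stand-ins, not part of this commit:

data_train, data_test = load_data_duo('../data/processed_data/png_image',
                                      batch_size=8, noise_threshold=100)
model = DuoModel()  # hypothetical two-branch network taking the (AER, ANA) image pair
for imgAER, imgANA, target in data_train:
    logits = model(imgAER, imgANA)  # one prediction per paired sample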
@@ -13,7 +13,7 @@ find . -name '*.mzML' -exec cp -prv '{}' '/home/leo/PycharmProjects/pseudo_image
copy the mzML files from the drive
"""
-def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx'):
+def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx', suffix='-d200'):
    """
    Extract and organise labels from the raw Excel file
    :param path: excel path
@@ -38,7 +38,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
        l = split_before_number(s)
        species = l[0]
        nb = l[1]
-        return '{}-{}-{}-d200.mzML'.format(species, nb, analyse)
+        return '{}-{}-{}{}.mzML'.format(species, nb, analyse, suffix)

    df['path_ana'] = df['sample_name'].map(lambda x: create_fname(x, analyse='ANA'))
    df['path_aer'] = df['sample_name'].map(lambda x: create_fname(x, analyse='AER'))
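For illustration, with a hypothetical sample name 'ecoli12' (split into species 'ecoli' and number '12'), create_fname would produce:

# suffix='-d200'         -> 'ecoli-12-ANA-d200.mzML'
# suffix='_100vW_100SPD' -> 'ecoli-12-ANA_100vW_100SPD.mzML'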
@@ -51,7 +51,30 @@ def create_dataset():
    Create images from the raw .mzML files and sort them into their corresponding class directories
    :return: None
    """
-    label = create_antibio_dataset()
+    label = create_antibio_dataset(suffix='-d200')
    for path in glob.glob("../data/raw_data/**.mzML"):
        print(path)
        species = None
        if path.split("/")[-1] in label['path_ana'].values:
            species = label[label['path_ana'] == path.split("/")[-1]]['species'].values[0]
            name = label[label['path_ana'] == path.split("/")[-1]]['sample_name'].values[0]
            analyse = 'ANA'
        elif path.split("/")[-1] in label['path_aer'].values:
            species = label[label['path_aer'] == path.split("/")[-1]]['species'].values[0]
            name = label[label['path_aer'] == path.split("/")[-1]]['sample_name'].values[0]
            analyse = 'AER'
        if species is not None:
            directory_path_png = '../data/processed_data/png_image/{}'.format(species)
            directory_path_npy = '../data/processed_data/npy_image/{}'.format(species)
            if not os.path.isdir(directory_path_png):
                os.makedirs(directory_path_png)
            if not os.path.isdir(directory_path_npy):
                os.makedirs(directory_path_npy)
            mat = build_image_ms1(path, 1)
            mpimg.imsave(directory_path_png + "/" + name + '_' + analyse + '.png', mat)
            np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
    label = create_antibio_dataset(suffix='_100vW_100SPD')
    for path in glob.glob("../data/raw_data/**.mzML"):
        print(path)
        species = None
...
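The resulting processed-data layout, which the duo dataset above consumes (species and sample names are placeholders):

# ../data/processed_data/png_image/<species>/<sample>_<ANA|AER>.png
# ../data/processed_data/npy_image/<species>/<sample>_<ANA|AER>.npy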
@@ -89,10 +89,10 @@ def run(args):
    plt.plot(train_acc)
    plt.show()
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold, args.lr, args.model))
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))
    load_model(model, args.save_path)
-    make_prediction(model, data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold, args.lr, args.model))
+    make_prediction(model, data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))

def make_prediction(model, data, f_name):
...