From b7bfdcab9c75a6588c67f9e21570579a9922d2de Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Wed, 12 Mar 2025 11:17:48 +0100 Subject: [PATCH] model cuda loading --- config/config.py | 1 + dataset/dataset.py | 35 +++++++++++++++++------------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/config/config.py b/config/config.py index bad9ecb..d898b6b 100644 --- a/config/config.py +++ b/config/config.py @@ -7,6 +7,7 @@ def load_args(): parser.add_argument('--epoches', type=int, default=100) parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--eval_inter', type=int, default=1) + parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') diff --git a/dataset/dataset.py b/dataset/dataset.py index ec59fea..6f807d0 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -26,25 +26,24 @@ class Random_shift_rt(): pass -def load_data(base_dir, batch_size, shuffle=True, transform=None): - if transform is None : - train_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224,224)), - Threshold_noise(500), - Log_normalisation(), - transforms.Normalize(0.5, 0.5)]) - print('Default train transform') +def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): + train_transform = transforms.Compose( + [transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Resize((224,224)), + Threshold_noise(noise_threshold), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + print('Default train transform') - val_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224,224)), - Threshold_noise(500), - Log_normalisation(), - transforms.Normalize(0.5, 0.5)]) - print('Default val transform') + val_transform = transforms.Compose( + [transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Resize((224,224)), + Threshold_noise(noise_threshold), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + print('Default val transform') train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) generator1 = torch.Generator().manual_seed(42) -- GitLab