From b7bfdcab9c75a6588c67f9e21570579a9922d2de Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 12 Mar 2025 11:17:48 +0100
Subject: [PATCH] model cuda loading

---
 config/config.py   |  1 +
 dataset/dataset.py | 35 +++++++++++++++++------------------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/config/config.py b/config/config.py
index bad9ecb..d898b6b 100644
--- a/config/config.py
+++ b/config/config.py
@@ -7,6 +7,7 @@ def load_args():
     parser.add_argument('--epoches', type=int, default=100)
     parser.add_argument('--save_inter', type=int, default=50)
     parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
diff --git a/dataset/dataset.py b/dataset/dataset.py
index ec59fea..6f807d0 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -26,25 +26,24 @@ class Random_shift_rt():
     pass
 
 
-def load_data(base_dir, batch_size, shuffle=True, transform=None):
-    if transform is None :
-        train_transform = transforms.Compose(
-            [transforms.Grayscale(num_output_channels=1),
-             transforms.ToTensor(),
-             transforms.Resize((224,224)),
-             Threshold_noise(500),
-             Log_normalisation(),
-             transforms.Normalize(0.5, 0.5)])
-        print('Default train transform')
+def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
+    train_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224,224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
 
-        val_transform = transforms.Compose(
-            [transforms.Grayscale(num_output_channels=1),
-             transforms.ToTensor(),
-             transforms.Resize((224,224)),
-             Threshold_noise(500),
-             Log_normalisation(),
-             transforms.Normalize(0.5, 0.5)])
-        print('Default val transform')
+    val_transform = transforms.Compose(
+        [transforms.Grayscale(num_output_channels=1),
+         transforms.ToTensor(),
+         transforms.Resize((224,224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
     train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
     val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)
-- 
GitLab