From f01e80b6ae48571444db0b95f5b00f9fe55bd0ac Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 14:44:46 +0200
Subject: [PATCH] fix : dataset ray base

---
 dataset/dataset.py | 13 +++++++------
 main_ray.py        |  4 ++--
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 9f84e63..5e05428 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -34,6 +34,11 @@ class Random_shift_rt():
     pass
 
 
+def npy_loader(path):
+    sample = torch.from_numpy(np.load(path))
+    sample = sample.unsqueeze(0)
+    return sample
+
 def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
     train_transform = transforms.Compose(
         [transforms.Grayscale(num_output_channels=1),
@@ -52,8 +57,8 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
          Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
-    train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
-    val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
+    train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform,loader=npy_loader)
+    val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform, loader=npy_loader)
 
     train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=42, shuffle=True,
                                              stratify=True)
@@ -86,10 +91,6 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
 def default_loader(path):
     return Image.open(path).convert('RGB')
 
-def npy_loader(path):
-    sample = torch.from_numpy(np.load(path))
-    sample = sample.unsqueeze(0)
-    return sample
 
 def remove_aer_ana(l):
     l = map(lambda x : x.split('_')[0],l)
diff --git a/main_ray.py b/main_ray.py
index 6734eca..2a32229 100644
--- a/main_ray.py
+++ b/main_ray.py
@@ -26,7 +26,7 @@ def train_model(config,args):
                                                   )
 
     # load model
-    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.classes))
+    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes))
 
     # move parameters to GPU
     model.float()
@@ -144,7 +144,7 @@ def test_model(best_result, args):
                                          noise_threshold=best_result.config['noise'])
 
     # load model
-    model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.classes))
+    model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.dataset.classes))
     model.float()
     # load weight
     checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
-- 
GitLab