From 24c4752193837561b4118fd1d95afbdfee6bb8dc Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 7 Apr 2025 17:08:56 +0200
Subject: [PATCH] add weigthed sampler for training

---
 image_ref/config.py      |  1 +
 image_ref/dataset_ref.py | 40 +++++++++++++++++++++++++++++++---------
 image_ref/main.py        |  2 +-
 3 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/image_ref/config.py b/image_ref/config.py
index aac9449d..2d713c20 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -12,6 +12,7 @@ def load_args_contrastive():
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--positive_prop', type=int, default=30)
     parser.add_argument('--model', type=str, default='ResNet18')
+    parser.add_argument('--sampler', type=str, default=None)
     parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
     parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
     parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index f2bd8fcc..e5dc9871 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -12,6 +12,8 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union
 from pathlib import Path
 from collections import OrderedDict
 
+from torch.utils.data import WeightedRandomSampler
+
 IMG_EXTENSIONS = ".npy"
 
 class Threshold_noise:
@@ -152,7 +154,9 @@ class ImageFolderDuo(data.Dataset):
     def __len__(self):
         return len(self.imlist)
 
-def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None):
+def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
+
+
 
     train_transform = transforms.Compose(
         [transforms.Resize((224, 224)),
@@ -180,14 +184,32 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
     train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
     val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
 
-    data_loader_train = data.DataLoader(
-        dataset=train_dataset,
-        batch_size=batch_size,
-        shuffle=shuffle,
-        num_workers=0,
-        collate_fn=None,
-        pin_memory=False,
-    )
+    if sampler =='weighted' :
+        y_train_label = np.array([i for (_,_,i)in train_dataset.imlist])
+        class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)])
+        weight = 1. / class_sample_count
+        samples_weight = np.array([weight[t] for t in y_train_label])
+
+        samples_weight = torch.from_numpy(samples_weight)
+        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
+
+        data_loader_train = data.DataLoader(
+            dataset=train_dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+        )
+    else :
+        data_loader_train = data.DataLoader(
+            dataset=train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+        )
 
     data_loader_test = data.DataLoader(
         dataset=val_dataset,
diff --git a/image_ref/main.py b/image_ref/main.py
index 246df886..ba2bb9f9 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -74,7 +74,7 @@ def test_duo(model, data_test, loss_function, epoch):
 def run_duo(args):
     #load data
     data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size,
-                                          ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop)
+                                          ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler)
 
     #load model
     model = Classification_model_duo_contrastive(model = args.model, n_class=2)
-- 
GitLab