From 24c4752193837561b4118fd1d95afbdfee6bb8dc Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 7 Apr 2025 17:08:56 +0200 Subject: [PATCH] add weigthed sampler for training --- image_ref/config.py | 1 + image_ref/dataset_ref.py | 40 +++++++++++++++++++++++++++++++--------- image_ref/main.py | 2 +- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index aac9449d..2d713c20 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -12,6 +12,7 @@ def load_args_contrastive(): parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--positive_prop', type=int, default=30) parser.add_argument('--model', type=str, default='ResNet18') + parser.add_argument('--sampler', type=str, default=None) parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index f2bd8fcc..e5dc9871 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -12,6 +12,8 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union from pathlib import Path from collections import OrderedDict +from torch.utils.data import WeightedRandomSampler + IMG_EXTENSIONS = ".npy" class Threshold_noise: @@ -152,7 +154,9 @@ class ImageFolderDuo(data.Dataset): def __len__(self): return len(self.imlist) -def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None): +def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None): + + train_transform = transforms.Compose( [transforms.Resize((224, 224)), @@ -180,14 +184,32 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform) val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform) - data_loader_train = data.DataLoader( - dataset=train_dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=0, - collate_fn=None, - pin_memory=False, - ) + if sampler =='weighted' : + y_train_label = np.array([i for (_,_,i)in train_dataset.imlist]) + class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)]) + weight = 1. / class_sample_count + samples_weight = np.array([weight[t] for t in y_train_label]) + + samples_weight = torch.from_numpy(samples_weight) + sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight)) + + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + else : + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) data_loader_test = data.DataLoader( dataset=val_dataset, diff --git a/image_ref/main.py b/image_ref/main.py index 246df886..ba2bb9f9 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -74,7 +74,7 @@ def test_duo(model, data_test, loss_function, epoch): def run_duo(args): #load data data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size, - ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop) + ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler) #load model model = Classification_model_duo_contrastive(model = args.model, n_class=2) -- GitLab