diff --git a/image_ref/config.py b/image_ref/config.py index aac9449d288fb1a668c82291ad996c39ec5665bc..2d713c206a388e8020d7b982c73329616c7e152f 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -12,6 +12,7 @@ def load_args_contrastive(): parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--positive_prop', type=int, default=30) parser.add_argument('--model', type=str, default='ResNet18') + parser.add_argument('--sampler', type=str, default=None) parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index f2bd8fcca265441697a533676454d0c2fb071592..e5dc9871e532a646d4850db905c5148620c3c575 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -12,6 +12,8 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union from pathlib import Path from collections import OrderedDict +from torch.utils.data import WeightedRandomSampler + IMG_EXTENSIONS = ".npy" class Threshold_noise: @@ -152,7 +154,9 @@ class ImageFolderDuo(data.Dataset): def __len__(self): return len(self.imlist) -def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None): +def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None): + + train_transform = transforms.Compose( [transforms.Resize((224, 224)), @@ -180,14 +184,32 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform) val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform) - data_loader_train = data.DataLoader( - dataset=train_dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=0, - collate_fn=None, - pin_memory=False, - ) + if sampler =='weighted' : + y_train_label = np.array([i for (_,_,i)in train_dataset.imlist]) + class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)]) + weight = 1. / class_sample_count + samples_weight = np.array([weight[t] for t in y_train_label]) + + samples_weight = torch.from_numpy(samples_weight) + sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight)) + + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + else : + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) data_loader_test = data.DataLoader( dataset=val_dataset, diff --git a/image_ref/main.py b/image_ref/main.py index 246df8861c5ad6abe6d27078ce83c30dd4ac9d29..ba2bb9f9ffa7d99b01731d97a4c1bea71e253e6d 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -74,7 +74,7 @@ def test_duo(model, data_test, loss_function, epoch): def run_duo(args): #load data data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size, - ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop) + ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler) #load model model = Classification_model_duo_contrastive(model = args.model, n_class=2)