From 95a866fe4381f468a7371eea5cd83b20773f948a Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 15 Apr 2025 11:30:16 +0200 Subject: [PATCH] add : base dir option for ray --- image_ref/config.py | 1 + image_ref/dataset_ref.py | 28 ++++++++++++++++++---------- image_ref/main_ray.py | 1 + 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index 7137615b..d2e11660 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -16,6 +16,7 @@ def load_args_contrastive(): parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_test_dir', type=str, default=None) + parser.add_argument('--base_dir', type=str, default=None) parser.add_argument('--base_out', type=str, default='output/baseline') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index c48992f6..c73c3c47 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -131,15 +131,19 @@ def make_dataset_custom( class ImageFolderDuo(data.Dataset): def __init__(self, root, transform=None, target_transform=None, - flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None, ref_transform=None): + flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None, ref_transform=None, base_dir=None): self.root = root - self.imlist = flist_reader(root) self.transform = transform self.target_transform = target_transform self.ref_transform = ref_transform self.loader = loader self.classes = torchvision.datasets.folder.find_classes(root)[0] - self.ref_dir = ref_dir + if base_dir is not None : + self.ref_dir = os.path.join(base_dir,ref_dir) + self.imlist = flist_reader(os.path.join(base_dir,root)) + else : + self.ref_dir = ref_dir + self.imlist = flist_reader(root) self.positive_prop = positive_prop def __getitem__(self, index): @@ -169,7 +173,7 @@ class ImageFolderDuo(data.Dataset): def __len__(self): return len(self.imlist) -def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None): +def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None, base_dir=None): @@ -196,12 +200,12 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff print('Default val transform') - train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform) - val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform) + train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform, base_dir=base_dir) + val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform, base_dir=base_dir) if base_dir_test is not None : test_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir, - ref_transform=ref_transform) + ref_transform=ref_transform, base_dir=base_dir) if sampler =='balanced' : y_train_label = np.array([i for (_,_,i)in train_dataset.imlist]) @@ -256,15 +260,19 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff class ImageFolderDuo_Batched(data.Dataset): def __init__(self, root, transform=None, target_transform=None, - flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None): + flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None, base_dir=None): self.root = root - self.imlist = flist_reader(root) + if base_dir is not None: + self.ref_dir = os.path.join(base_dir, ref_dir) + self.imlist = flist_reader(os.path.join(base_dir, root)) + else: + self.ref_dir = ref_dir + self.imlist = flist_reader(root) self.transform = transform self.ref_transform = ref_transform self.target_transform = target_transform self.loader = loader self.classes = torchvision.datasets.folder.find_classes(root)[0] - self.ref_dir = ref_dir def __getitem__(self, index): impathAER, impathANA, target = self.imlist[index] diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 90f45583..04ab5a1a 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -152,6 +152,7 @@ def test_model(best_result, args): _, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, + base_dir=args.base_dir, batch_size=args.batch_size, ref_dir=args.dataset_ref_dir, noise_threshold=best_result.config['noise'], -- GitLab