Skip to content
Snippets Groups Projects
Commit 940fd506 authored by Schneider Leo's avatar Schneider Leo
Browse files

revert : base dir option for ray

parent 95a866fe
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,6 @@ def load_args_contrastive():
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
parser.add_argument('--dataset_test_dir', type=str, default=None)
parser.add_argument('--base_dir', type=str, default=None)
parser.add_argument('--base_out', type=str, default='output/baseline')
parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
......
......@@ -131,19 +131,15 @@ def make_dataset_custom(
class ImageFolderDuo(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None, ref_transform=None, base_dir=None):
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None, ref_transform=None):
self.root = root
self.imlist = flist_reader(root)
self.transform = transform
self.target_transform = target_transform
self.ref_transform = ref_transform
self.loader = loader
self.classes = torchvision.datasets.folder.find_classes(root)[0]
if base_dir is not None :
self.ref_dir = os.path.join(base_dir,ref_dir)
self.imlist = flist_reader(os.path.join(base_dir,root))
else :
self.ref_dir = ref_dir
self.imlist = flist_reader(root)
self.ref_dir = ref_dir
self.positive_prop = positive_prop
def __getitem__(self, index):
......@@ -173,7 +169,7 @@ class ImageFolderDuo(data.Dataset):
def __len__(self):
return len(self.imlist)
def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None, base_dir=None):
def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
......@@ -200,12 +196,12 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff
print('Default val transform')
train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform, base_dir=base_dir)
val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform, base_dir=base_dir)
train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
if base_dir_test is not None :
test_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir,
ref_transform=ref_transform, base_dir=base_dir)
ref_transform=ref_transform)
if sampler =='balanced' :
y_train_label = np.array([i for (_,_,i)in train_dataset.imlist])
......@@ -260,19 +256,15 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff
class ImageFolderDuo_Batched(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None, base_dir=None):
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None):
self.root = root
if base_dir is not None:
self.ref_dir = os.path.join(base_dir, ref_dir)
self.imlist = flist_reader(os.path.join(base_dir, root))
else:
self.ref_dir = ref_dir
self.imlist = flist_reader(root)
self.imlist = flist_reader(root)
self.transform = transform
self.ref_transform = ref_transform
self.target_transform = target_transform
self.loader = loader
self.classes = torchvision.datasets.folder.find_classes(root)[0]
self.ref_dir = ref_dir
def __getitem__(self, index):
impathAER, impathANA, target = self.imlist[index]
......
......@@ -152,7 +152,6 @@ def test_model(best_result, args):
_, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
base_dir_val=args.dataset_val_dir,
base_dir_test=args.dataset_test_dir,
base_dir=args.base_dir,
batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir,
noise_threshold=best_result.config['noise'],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment