diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index 5dd8ff3040094f74f290ddd695b26cc77cda0c2e..738820aa194f56a2f5e29188f7e2fc8c4fb26831 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -248,8 +248,8 @@ def load_data_duo_batched(base_dir, shuffle=True, noise_threshold=0, ref_dir = N transforms.Normalize(0.5, 0.5)]) print('Default val transform') - train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform, ref_dir = ref_dir) - val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform, ref_dir = ref_dir) + train_dataset = ImageFolderDuo_Batched(root=base_dir, transform=train_transform, ref_dir = ref_dir) + val_dataset = ImageFolderDuo_Batched(root=base_dir, transform=val_transform, ref_dir = ref_dir) generator1 = torch.Generator().manual_seed(42) indices = torch.randperm(len(train_dataset), generator=generator1) val_size = len(train_dataset) // 5