diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index 5dd8ff3040094f74f290ddd695b26cc77cda0c2e..738820aa194f56a2f5e29188f7e2fc8c4fb26831 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -248,8 +248,8 @@ def load_data_duo_batched(base_dir, shuffle=True, noise_threshold=0, ref_dir = N
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
 
-    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform, ref_dir = ref_dir)
-    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform, ref_dir = ref_dir)
+    train_dataset = ImageFolderDuo_Batched(root=base_dir, transform=train_transform, ref_dir = ref_dir)
+    val_dataset = ImageFolderDuo_Batched(root=base_dir, transform=val_transform, ref_dir = ref_dir)
     generator1 = torch.Generator().manual_seed(42)
     indices = torch.randperm(len(train_dataset), generator=generator1)
     val_size = len(train_dataset) // 5