diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index ae54025638fe176a54d74cde0b6610666bdf63c9..f0f329ce1b9a15fc3471e426d03daa94164f9dc6 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -225,8 +225,8 @@ class ImageFolderDuo_Batched(data.Dataset):
 
         batched_im_ref = torch.concat(img_refs,dim=0)
         batched_label = torch.tensor(label_refs)
-        batched_imgAER = imgAER.repeat(len(self.classes))
-        batched_imgANA = imgANA.repeat(len(self.classes))
+        batched_imgAER = imgAER.repeat(len(self.classes),1,1)
+        batched_imgANA = imgANA.repeat(len(self.classes),1,1)
 
         return batched_imgAER, batched_imgANA, batched_im_ref, batched_label