diff --git a/dataloader.py b/dataloader.py
index 121ad91c248fe6d0fb5b302521500b257f9b0d6b..d3b872e24d88f3ad13b482149bbe2d9ed40934fe 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -77,7 +77,8 @@ class RT_Dataset(Dataset):
             self.data = self.data[self.data.state == 'holdout']
         elif mode == 'validation':
             self.data = self.data[self.data.state == 'validation']
-
+        if size is not None:
+            self.data = self.data.sample(size)
 
         print('Padding')
         self.data['sequence'] = self.data['sequence'].str.pad(length, side='right', fillchar='_')