diff --git a/dataloader.py b/dataloader.py index 121ad91c248fe6d0fb5b302521500b257f9b0d6b..d3b872e24d88f3ad13b482149bbe2d9ed40934fe 100644 --- a/dataloader.py +++ b/dataloader.py @@ -77,7 +77,8 @@ class RT_Dataset(Dataset): self.data = self.data[self.data.state == 'holdout'] elif mode == 'validation': self.data = self.data[self.data.state == 'validation'] - + if size is not None: + self.data = self.data.sample(size) print('Padding') self.data['sequence'] = self.data['sequence'].str.pad(length, side='right', fillchar='_')