From bbc0dcf2e8252e80b97076ac1618405ae353130f Mon Sep 17 00:00:00 2001 From: schne <leo.schneider@ecl19.ec-lyon.fr> Date: Mon, 9 Sep 2024 13:30:41 +0200 Subject: [PATCH] dataset --- dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataloader.py b/dataloader.py index 121ad91..d3b872e 100644 --- a/dataloader.py +++ b/dataloader.py @@ -77,7 +77,8 @@ class RT_Dataset(Dataset): self.data = self.data[self.data.state == 'holdout'] elif mode == 'validation': self.data = self.data[self.data.state == 'validation'] - + if size is not None: + self.data = self.data.sample(size) print('Padding') self.data['sequence'] = self.data['sequence'].str.pad(length, side='right', fillchar='_') -- GitLab