From bbc0dcf2e8252e80b97076ac1618405ae353130f Mon Sep 17 00:00:00 2001
From: schne <leo.schneider@ecl19.ec-lyon.fr>
Date: Mon, 9 Sep 2024 13:30:41 +0200
Subject: [PATCH] dataset

---
 dataloader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dataloader.py b/dataloader.py
index 121ad91..d3b872e 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -77,7 +77,8 @@ class RT_Dataset(Dataset):
             self.data = self.data[self.data.state == 'holdout']
         elif mode == 'validation':
             self.data = self.data[self.data.state == 'validation']
-
+        if size is not None:
+            self.data = self.data.sample(size)
 
         print('Padding')
         self.data['sequence'] = self.data['sequence'].str.pad(length, side='right', fillchar='_')
-- 
GitLab