From 5d66520d001b00d95b7b136397c10a6fdff61c4a Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 20 Sep 2024 11:52:57 +0200 Subject: [PATCH] dataloader fix --- dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataloader.py b/dataloader.py index d3b872e..81a2d20 100644 --- a/dataloader.py +++ b/dataloader.py @@ -109,7 +109,7 @@ def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30): print('Loading data') train = RT_Dataset(n_train, data_sources[0], 'train', length) test = RT_Dataset(n_test, data_sources[1], 'test', length) - val = RT_Dataset(n_test, data_sources[2], 'validation', length) + val = RT_Dataset(n_test, data_sources[2], 'test', length) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val, batch_size=batch_size, shuffle=True) -- GitLab