diff --git a/dataloader.py b/dataloader.py index d3b872e24d88f3ad13b482149bbe2d9ed40934fe..81a2d20b8f3d96618b47f833c115d8b65b1a7fc3 100644 --- a/dataloader.py +++ b/dataloader.py @@ -109,7 +109,7 @@ def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30): print('Loading data') train = RT_Dataset(n_train, data_sources[0], 'train', length) test = RT_Dataset(n_test, data_sources[1], 'test', length) - val = RT_Dataset(n_test, data_sources[2], 'validation', length) + val = RT_Dataset(n_test, data_sources[2], 'test', length) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)