diff --git a/common_dataset.py b/common_dataset.py index 9e293f17be4dcc4e25317500ed3946982d8f0808..671048ef66e712f592a82327bf6ea89c0e246613 100644 --- a/common_dataset.py +++ b/common_dataset.py @@ -152,8 +152,8 @@ class Common_Dataset(Dataset): def load_data(path_train, path_val, path_test, batch_size, length, pad=False, convert=False, vocab = 'unmod'): print('Loading data') - data_val = pd.read_pickle(path_val) data_train = pd.read_pickle(path_train) + data_val = pd.read_pickle(path_val) data_test = pd.read_pickle(path_test) train = Common_Dataset(data_train, length, pad, convert, vocab) test = Common_Dataset(data_val, length, pad, convert, vocab)