diff --git a/main.py b/main.py index 936e040d24cb16b9c4fbacd8e64a550ea258a44e..97cecf39b7b11529db2af847707808a58a76f273 100644 --- a/main.py +++ b/main.py @@ -272,7 +272,7 @@ def main_rt(args): data_sources=[args.dataset_train, args.dataset_train, args.dataset_train]) else: data_train, data_val, data_test = load_data(batch_size=args.batch_size, n_train=args.n_train, n_test=args.n_test, - data_sources=[args.dataset_train,args.dataset_test,args.dataset_train]) + data_sources=[args.dataset_train,args.dataset_train,args.dataset_test]) print('\nData loaded') # if args.model == 'RT_self_att' : # model = RT_pred_model_self_attention()