From 6a63684cf21afc86c73ddb2d5cad026391ff14fd Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 11 Feb 2025 12:40:19 +0100 Subject: [PATCH] dataset pred for diann --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 3d614da..d3b2cd4 100644 --- a/main.py +++ b/main.py @@ -95,9 +95,9 @@ def main(args): print(args) print('Cuda : ', torch.cuda.is_available()) - data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train, seq_col=args.seq_train) - data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test, seq_col=args.seq_test) - data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val, seq_col=args.seq_val) + data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=30, mode=args.split_train, seq_col=args.seq_train) + data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=30, mode=args.split_test, seq_col=args.seq_test) + data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=30, mode=args.split_val, seq_col=args.seq_val) print('\nData loaded') model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, -- GitLab