From aba4f1cd12248ce6b21d275bf3ed8387dc812393 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 15 Oct 2024 13:48:11 +0200 Subject: [PATCH] seq_length args --- config_common.py | 3 +++ main_custom.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/config_common.py b/config_common.py index 09dea8f..3b9a019 100644 --- a/config_common.py +++ b/config_common.py @@ -1,5 +1,7 @@ import argparse +from tensorflow.python.keras.utils.generic_utils import default + def load_args(): parser = argparse.ArgumentParser() @@ -27,6 +29,7 @@ def load_args(): parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction) parser.add_argument('--activation', type=str,default='relu') parser.add_argument('--file', action=argparse.BooleanOptionalAction) + parser.add_argument('--seq_length', type=int, default=25) args = parser.parse_args() return args diff --git a/main_custom.py b/main_custom.py index 24d11ff..8aa621b 100644 --- a/main_custom.py +++ b/main_custom.py @@ -223,28 +223,28 @@ def main(args): data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train, path_val=args.dataset_val, path_test=args.dataset_test, - batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod') + batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod') elif args.forward == 'rt': data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test], - batch_size=args.batch_size, length=25) + batch_size=args.batch_size, length=args.seq_length) elif args.forward == 'transfer': data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'], - batch_size=args.batch_size, length=25) + batch_size=args.batch_size, length=args.seq_length) _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val, path_val=args.dataset_val, path_test=args.dataset_test, - batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod') + batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod') elif args.forward == 'reverse': _, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test], - batch_size=args.batch_size, length=25) + batch_size=args.batch_size, length=args.seq_length) data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train, path_val=args.dataset_train, path_test=args.dataset_train, - batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod') + batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod') print('\nData loaded') @@ -253,7 +253,8 @@ def main(args): , n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, decoder_int_num_layer=args.decoder_int_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, - embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) + embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, + seq_length=args.seq_length) if torch.cuda.is_available(): model = model.cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr) -- GitLab