From aba4f1cd12248ce6b21d275bf3ed8387dc812393 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 15 Oct 2024 13:48:11 +0200
Subject: [PATCH] Add --seq_length argument

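Replace the hard-coded sequence length of 25 with a --seq_length
command-line argument (default 25). The value is threaded through
every data-loading call in main_custom.py and passed on to the model
constructor.

Example invocation (other required arguments elided):

    python main_custom.py --forward rt --seq_length 30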
---
 config_common.py |  1 +
 main_custom.py   | 15 ++++++++-------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/config_common.py b/config_common.py
index 09dea8f..3b9a019 100644
--- a/config_common.py
+++ b/config_common.py
@@ -27,6 +27,7 @@ def load_args():
     parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
     parser.add_argument('--activation', type=str,default='relu')
     parser.add_argument('--file', action=argparse.BooleanOptionalAction)
+    parser.add_argument('--seq_length', type=int, default=25)
     args = parser.parse_args()
 
     return args
diff --git a/main_custom.py b/main_custom.py
index 24d11ff..8aa621b 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -223,28 +223,28 @@ def main(args):
         data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                    path_val=args.dataset_val,
                                                                    path_test=args.dataset_test,
-                                                                   batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
+                                                                   batch_size=args.batch_size, length=args.seq_length, pad=True, convert=True, vocab='unmod')
     elif args.forward == 'rt':
         data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
-                                                               batch_size=args.batch_size, length=25)
+                                                               batch_size=args.batch_size, length=args.seq_length)
 
     elif args.forward == 'transfer':
         data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
-                                                               batch_size=args.batch_size, length=25)
+                                                               batch_size=args.batch_size, length=args.seq_length)
 
         _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
                                                                    path_val=args.dataset_val,
                                                                    path_test=args.dataset_test,
-                                                                   batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
+                                                                   batch_size=args.batch_size, length=args.seq_length, pad=True, convert=True, vocab='unmod')
 
     elif args.forward == 'reverse':
         _, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
-                                                               batch_size=args.batch_size, length=25)
+                                                               batch_size=args.batch_size, length=args.seq_length)
 
         data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train,
                                                                    path_val=args.dataset_train,
                                                                    path_test=args.dataset_train,
-                                                                   batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
+                                                                   batch_size=args.batch_size, length=args.seq_length, pad=True, convert=True, vocab='unmod')
 
     print('\nData loaded')
 
@@ -253,7 +253,8 @@ def main(args):
                                      , n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                      decoder_int_num_layer=args.decoder_int_num_layer,
                                      decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
-                                     embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+                                     embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
+                                     seq_length=args.seq_length)
     if torch.cuda.is_available():
         model = model.cuda()
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
-- 
GitLab