Skip to content
Snippets Groups Projects
Commit aba4f1cd authored by Schneider Leo's avatar Schneider Leo
Browse files

seq_length args

parent 538eee31
No related branches found
No related tags found
No related merge requests found
import argparse
from tensorflow.python.keras.utils.generic_utils import default
def load_args():
parser = argparse.ArgumentParser()
......@@ -27,6 +29,7 @@ def load_args():
parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
parser.add_argument('--activation', type=str,default='relu')
parser.add_argument('--file', action=argparse.BooleanOptionalAction)
parser.add_argument('--seq_length', type=int, default=25)
args = parser.parse_args()
return args
......@@ -223,28 +223,28 @@ def main(args):
data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
path_val=args.dataset_val,
path_test=args.dataset_test,
batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
elif args.forward == 'rt':
data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
batch_size=args.batch_size, length=25)
batch_size=args.batch_size, length=args.seq_length)
elif args.forward == 'transfer':
data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
batch_size=args.batch_size, length=25)
batch_size=args.batch_size, length=args.seq_length)
_, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
path_val=args.dataset_val,
path_test=args.dataset_test,
batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
elif args.forward == 'reverse':
_, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
batch_size=args.batch_size, length=25)
batch_size=args.batch_size, length=args.seq_length)
data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train,
path_val=args.dataset_train,
path_test=args.dataset_train,
batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='unmod')
batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
print('\nData loaded')
......@@ -253,7 +253,8 @@ def main(args):
, n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
decoder_int_num_layer=args.decoder_int_num_layer,
decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
seq_length=args.seq_length)
if torch.cuda.is_available():
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment