diff --git a/main_custom.py b/main_custom.py index db984b4baa652a0b7f364dd7c1ffacd5d81c3f87..fcc907494f8c2419cfd47376f0c5a9f4fac3a94c 100644 --- a/main_custom.py +++ b/main_custom.py @@ -207,7 +207,7 @@ def main(args): print('\nData loaded') - model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, + model = Model_Common_Transformer_TAPE(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, decoder_int_ff=args.decoder_int_ff , n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, decoder_int_num_layer=args.decoder_int_num_layer,