diff --git a/main_custom.py b/main_custom.py
index db984b4baa652a0b7f364dd7c1ffacd5d81c3f87..fcc907494f8c2419cfd47376f0c5a9f4fac3a94c 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -207,7 +207,7 @@ def main(args):
 
     print('\nData loaded')
 
-    model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
+    model = Model_Common_Transformer_TAPE(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                                      decoder_int_ff=args.decoder_int_ff
                                      , n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                      decoder_int_num_layer=args.decoder_int_num_layer,