diff --git a/config.py b/config.py
index d2e5d2d71f1c0715b27b37cfc00187de72677fee..124672da0befa1b7cca6ee6912ec288cdf5ca8b2 100644
--- a/config.py
+++ b/config.py
@@ -30,6 +30,7 @@ def load_args():
     parser.add_argument('--seq_test', type=str, default='sequence')
     parser.add_argument('--seq_val', type=str, default='sequence')
     parser.add_argument('--n_head', type=int, default=1)
+    parser.add_argument('--model_weigh', type=str, default=None)
     args = parser.parse_args()
 
     return args
diff --git a/main.py b/main.py
index 5113e81b9b78b79636f103b745027e3353535f69..472cdb9986316e56e8ce41b89173c328ff129c9c 100644
--- a/main.py
+++ b/main.py
@@ -104,6 +104,10 @@ def main(args):
                               n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                               decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                               embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+
+    if args.model_weigh is not None :
+        model.load_state_dict(torch.load(args.model_weigh+'.pt', weights_only=True))
+
     if torch.cuda.is_available():
         model = model.cuda()
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
@@ -185,3 +189,4 @@ if __name__ == "__main__":
 
 
 
+#output/out_coli_augmented_04_coli_8.pt
\ No newline at end of file