From f05c0502c0e8edc8105077bfbcb195238f2e0af5 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 23 Jan 2025 16:41:38 +0100
Subject: [PATCH] Transfer learning: load pretrained weights via --model_weigh

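Add a --model_weigh argument (default None) to config.py. When it is set,
main.py loads the state dict from "<model_weigh>.pt" into the model before
training, so a run can start from pretrained weights instead of a random
initialization.

Example invocation, assuming a checkpoint saved as
output/out_coli_augmented_04_coli_8.pt (the path noted in the comment at the
end of main.py) and leaving every other argument at its default:

    python main.py --model_weigh output/out_coli_augmented_04_coli_8
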
---
 config.py | 1 +
 main.py   | 5 +++++
 2 files changed, 6 insertions(+)

diff --git a/config.py b/config.py
index d2e5d2d..124672d 100644
--- a/config.py
+++ b/config.py
@@ -30,6 +30,7 @@ def load_args():
     parser.add_argument('--seq_test', type=str, default='sequence')
     parser.add_argument('--seq_val', type=str, default='sequence')
     parser.add_argument('--n_head', type=int, default=1)
+    parser.add_argument('--model_weigh', type=str, default=None, help='path prefix of a pretrained checkpoint; <model_weigh>.pt is loaded before training')
     args = parser.parse_args()
 
     return args
diff --git a/main.py b/main.py
index 5113e81..472cdb9 100644
--- a/main.py
+++ b/main.py
@@ -104,6 +104,10 @@ def main(args):
                               n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                               decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                               embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+
+    if args.model_weigh is not None:
+        model.load_state_dict(torch.load(args.model_weigh + '.pt', map_location='cpu', weights_only=True))
+
     if torch.cuda.is_available():
         model = model.cuda()
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
@@ -185,3 +189,4 @@ if __name__ == "__main__":
 
 
 
+#output/out_coli_augmented_04_coli_8.pt
\ No newline at end of file
-- 
GitLab