From f05c0502c0e8edc8105077bfbcb195238f2e0af5 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Thu, 23 Jan 2025 16:41:38 +0100 Subject: [PATCH] transfer learning --- config.py | 1 + main.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/config.py b/config.py index d2e5d2d..124672d 100644 --- a/config.py +++ b/config.py @@ -30,6 +30,7 @@ def load_args(): parser.add_argument('--seq_test', type=str, default='sequence') parser.add_argument('--seq_val', type=str, default='sequence') parser.add_argument('--n_head', type=int, default=1) + parser.add_argument('--model_weigh', type=str, default=None) args = parser.parse_args() return args diff --git a/main.py b/main.py index 5113e81..472cdb9 100644 --- a/main.py +++ b/main.py @@ -104,6 +104,10 @@ def main(args): n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) + + if args.model_weigh is not None : + model.load_state_dict(torch.load(args.model_weigh+'.pt', weights_only=True)) + if torch.cuda.is_available(): model = model.cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr) @@ -185,3 +189,4 @@ if __name__ == "__main__": +#output/out_coli_augmented_04_coli_8.pt \ No newline at end of file -- GitLab