diff --git a/config.py b/config.py index d2e5d2d71f1c0715b27b37cfc00187de72677fee..124672da0befa1b7cca6ee6912ec288cdf5ca8b2 100644 --- a/config.py +++ b/config.py @@ -30,6 +30,7 @@ def load_args(): parser.add_argument('--seq_test', type=str, default='sequence') parser.add_argument('--seq_val', type=str, default='sequence') parser.add_argument('--n_head', type=int, default=1) + parser.add_argument('--model_weigh', type=str, default=None) args = parser.parse_args() return args diff --git a/main.py b/main.py index 5113e81b9b78b79636f103b745027e3353535f69..472cdb9986316e56e8ce41b89173c328ff129c9c 100644 --- a/main.py +++ b/main.py @@ -104,6 +104,10 @@ def main(args): n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) + + if args.model_weigh is not None : + model.load_state_dict(torch.load(args.model_weigh+'.pt', weights_only=True)) + if torch.cuda.is_available(): model = model.cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr) @@ -185,3 +189,4 @@ if __name__ == "__main__": +#output/out_coli_augmented_04_coli_8.pt \ No newline at end of file