Skip to content
Snippets Groups Projects
Commit f05c0502 authored by Schneider Leo's avatar Schneider Leo
Browse files

transfer learning

parent 2d98534b
No related branches found
No related tags found
No related merge requests found
...@@ -30,6 +30,7 @@ def load_args(): ...@@ -30,6 +30,7 @@ def load_args():
parser.add_argument('--seq_test', type=str, default='sequence') parser.add_argument('--seq_test', type=str, default='sequence')
parser.add_argument('--seq_val', type=str, default='sequence') parser.add_argument('--seq_val', type=str, default='sequence')
parser.add_argument('--n_head', type=int, default=1) parser.add_argument('--n_head', type=int, default=1)
parser.add_argument('--model_weigh', type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -104,6 +104,10 @@ def main(args): ...@@ -104,6 +104,10 @@ def main(args):
n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
if args.model_weigh is not None :
model.load_state_dict(torch.load(args.model_weigh+'.pt', weights_only=True))
if torch.cuda.is_available(): if torch.cuda.is_available():
model = model.cuda() model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
...@@ -185,3 +189,4 @@ if __name__ == "__main__": ...@@ -185,3 +189,4 @@ if __name__ == "__main__":
#output/out_coli_augmented_04_coli_8.pt
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment