diff --git a/diann_lib_processing.py b/diann_lib_processing.py
index 1a95c514440491bbce2e68b6e0a5be222f28ffaf..dbf35f03b41766b4c4a93f8588a5589824c0ffd7 100644
--- a/diann_lib_processing.py
+++ b/diann_lib_processing.py
@@ -65,7 +65,7 @@ if __name__ =='__main__':
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, n_head=args.n_head,
                              encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer,
                              drop_rate=args.drop_rate,
-                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)

     if torch.cuda.is_available():
         model = model.cuda()
diff --git a/main.py b/main.py
index 9893370015b16305e26f6abaf6e919fcea0afe9e..3d614daa4eba4a2b2594cd5663047858e0e21da2 100644
--- a/main.py
+++ b/main.py
@@ -103,7 +103,8 @@ def main(args):
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, n_head=args.n_head,
                              encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer,
                              drop_rate=args.drop_rate,
-                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
+                             seq_length=30)

     if args.model_weigh is not None :
         model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
diff --git a/model/model.py b/model/model.py
index 0d962cc004111acbaaa77ac5bdca5ac266a2215a..8c74e76bf9d362a1baa4b6bf217f44afe1da24c7 100644
--- a/model/model.py
+++ b/model/model.py
@@ -38,7 +38,7 @@ class PositionalEncoding(nn.Module):

 class ModelTransformer(nn.Module):
     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=22, regressor_layer_size_rt=512, decoder_rt_ff=512,
-                 n_head=1, seq_length=25, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, acti='relu',
+                 n_head=1, seq_length=30, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, acti='relu',
                  norm=False):
        self.seq_length = seq_length
        self.nb_aa = nb_aa
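
Context for why `seq_length` matters here: the repo's `ModelTransformer` contains a `PositionalEncoding` module (visible in the `model/model.py` hunk above), and sinusoidal positional encodings are typically precomputed as a fixed table of `seq_length` positions. Raising the default from 25 to 30, and passing `seq_length=30` explicitly at both call sites, keeps the table large enough for longer peptide sequences. The sketch below is a minimal, generic sinusoidal implementation, not the repo's actual `model.py` code, illustrating the failure mode that a too-small `seq_length` would produce:

```python
# Hypothetical sketch of a sinusoidal positional encoding with a
# precomputed table; NOT the repo's implementation.
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim: int, seq_length: int):
        super().__init__()
        # Precompute encodings for positions 0 .. seq_length - 1.
        position = torch.arange(seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        pe = torch.zeros(seq_length, embedding_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embedding_dim). If seq_len exceeds the
        # precomputed table, the slice below is shorter than x and the
        # addition raises a shape-mismatch error at inference time.
        return x + self.pe[: x.size(1)]
```

With a table of 25 positions, a batch of length-30 sequences would fail on the `x + self.pe[...]` broadcast; bumping `seq_length` to 30 in the constructor default and at both call sites avoids that, assuming the model is only ever fed sequences of at most 30 residues.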