From e4d23e3e748fc282242bea9be090d095ac6a1049 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 11 Feb 2025 11:57:35 +0100
Subject: [PATCH] dataset pred for diann

---
 diann_lib_processing.py | 2 +-
 main.py                 | 3 ++-
 model/model.py          | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/diann_lib_processing.py b/diann_lib_processing.py
index 1a95c51..dbf35f0 100644
--- a/diann_lib_processing.py
+++ b/diann_lib_processing.py
@@ -65,7 +65,7 @@ if __name__ =='__main__':
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, n_head=args.n_head,
                              encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer,
                              drop_rate=args.drop_rate,
-                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)

     if torch.cuda.is_available():
         model = model.cuda()
diff --git a/main.py b/main.py
index 9893370..3d614da 100644
--- a/main.py
+++ b/main.py
@@ -103,7 +103,8 @@ def main(args):
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, n_head=args.n_head,
                              encoder_num_layer=args.encoder_num_layer, decoder_rt_num_layer=args.decoder_rt_num_layer,
                              drop_rate=args.drop_rate,
-                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
+                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
+                             seq_length=30)

     if args.model_weigh is not None :
         model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
diff --git a/model/model.py b/model/model.py
index 0d962cc..8c74e76 100644
--- a/model/model.py
+++ b/model/model.py
@@ -38,7 +38,7 @@ class PositionalEncoding(nn.Module):
 class ModelTransformer(nn.Module):

     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=22, regressor_layer_size_rt=512, decoder_rt_ff=512,
-                 n_head=1, seq_length=25, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, acti='relu',
+                 n_head=1, seq_length=30, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, acti='relu',
                  norm=False):
        self.seq_length = seq_length
        self.nb_aa = nb_aa
--
GitLab
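
A minimal sketch of instantiating the model after this patch, assuming the layout shown above (model/model.py defines ModelTransformer and is importable as a package) and leaving every other constructor argument at its default:

    import torch
    from model.model import ModelTransformer

    # seq_length now defaults to 30, and the call sites in this patch also pass
    # it explicitly, keeping the call sites and the model default in sync.
    model = ModelTransformer(seq_length=30)

    if torch.cuda.is_available():
        model = model.cuda()  # same device handling as in diann_lib_processing.py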