From 74beea9aa5f3583fc15271650da6419b38aac253 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 21 Oct 2024 14:03:24 +0200
Subject: [PATCH] datasets

---
 model_custom.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/model_custom.py b/model_custom.py
index 30fb32f..abb8685 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -42,7 +42,7 @@ class Model_Common_Transformer(nn.Module):
 
     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=23, regressor_layer_size_rt=512,
                  regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512, n_head=1, seq_length=25,
-                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
+                 charge_max=5, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                  decoder_int_num_layer=1, acti='relu', norm=False):
         self.charge_max = charge_max
         self.seq_length = seq_length
@@ -101,14 +101,9 @@ class Model_Common_Transformer(nn.Module):
                                               d_model=self.embedding_dim)
 
     def forward(self, seq, charge):
-        print(seq.size(),charge.size())
-        print(seq, charge)
         print(torch.cuda.mem_get_info())
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        print(seq_emb.shape)
-        print(self.nb_aa, self.embedding_dim)
-        print(seq_emb)
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
 
--
GitLab
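
Note (not part of the patch): the bump of the default charge_max from 4 to 5 matters because forward() one-hot encodes charge - 1 with num_classes=self.charge_max, and torch.nn.functional.one_hot rejects any index that is not strictly smaller than num_classes. A minimal sketch of that constraint, using hypothetical charge values and assuming precursor charges up to 5 now occur in the datasets:

    import torch

    charge_max = 5                             # new default introduced by this patch
    charge = torch.tensor([1, 2, 5])           # hypothetical batch of precursor charges
    meta_ohe = torch.nn.functional.one_hot(charge - 1, charge_max).float()
    print(meta_ohe.shape)                      # torch.Size([3, 5])

    # With the previous default (charge_max=4) a charge-5 precursor maps to index 4,
    # which is out of range for num_classes=4 and raises a RuntimeError:
    # torch.nn.functional.one_hot(torch.tensor([4]), 4)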