From d76c6464cf8712ea02411109b3829f7b556175d0 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 21 Oct 2024 14:15:05 +0200
Subject: [PATCH] datasets

---
 model_custom.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/model_custom.py b/model_custom.py
index abb8685..e4601b3 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -101,16 +101,22 @@ class Model_Common_Transformer(nn.Module):
                                  d_model=self.embedding_dim)
 
     def forward(self, seq, charge):
-        print(torch.cuda.mem_get_info())
+        print('seq', seq.size())
+        print('charge', charge.size())
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
+        print('meta_ohe', meta_ohe.size())
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
+        print('seq_emb', seq_emb.size())
         emb = self.pos_embedding(self.emb(seq_emb))
+        print('emb', emb.size())
         meta_enc = self.meta_enc(meta_ohe)
-
+        print('meta_enc', meta_enc.size())
         enc = self.encoder(emb)
-
+        print('enc', enc.size())
         out_rt = self.decoder_RT(enc)
+        print('out_rt', out_rt.size())
         int_enc = torch.mul(enc, meta_enc)
+        print('int_enc', int_enc.size())
         out_int = self.decoder_int(int_enc)
 
         return out_rt.flatten(), out_int
--
GitLab
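
Note: the patch traces tensor shapes with print() calls placed inside forward(). A minimal sketch of the same shape tracing done with PyTorch forward hooks follows, so the instrumentation can be attached and removed without editing the model code. SmallModel and shape_hook below are hypothetical stand-ins, not the project's code; only the hook pattern carries over to Model_Common_Transformer.

import torch
import torch.nn as nn

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Linear(20, 32)      # stands in for the embedding step
        self.encoder = nn.Linear(32, 32)  # stands in for the transformer encoder
        self.decoder_RT = nn.Linear(32, 1)

    def forward(self, x):
        emb = self.emb(x)
        enc = self.encoder(emb)
        return self.decoder_RT(enc)

def shape_hook(name):
    # Print the output shape of the named submodule each time it runs.
    def hook(module, inputs, output):
        print(name, tuple(output.size()))
    return hook

model = SmallModel()
# Attach one hook per submodule; named_modules() yields ('', root) first,
# so the empty name filters out the root module itself.
handles = [m.register_forward_hook(shape_hook(n))
           for n, m in model.named_modules() if n]

model(torch.randn(4, 20))  # prints the emb/encoder/decoder_RT output shapes

for h in handles:
    h.remove()  # detach the hooks once debugging is done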