diff --git a/model_custom.py b/model_custom.py
index 7f8a7d107d9091b5250354416087737799061f97..2c024d15fb0abf47257aac3030833ee2c1166973 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -101,15 +101,24 @@ class Model_Common_Transformer(nn.Module):
                                                 d_model=self.embedding_dim)
 
     def forward(self, seq, charge):
+        print('seq', seq)
+        print('charge', charge)
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
+        print('meta_ohe', meta_ohe)
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
+        print('seq_emb', seq_emb)
         emb = self.pos_embedding(self.emb(seq_emb))
-        meta_enc = self.meta_enc(meta_ohe)
         print('emb', emb)
+        meta_enc = self.meta_enc(meta_ohe)
+        print('meta_enc', meta_enc)
         enc = self.encoder(emb)
+        print('enc', enc)
         out_rt = self.decoder_RT(enc)
+        print('out_rt', out_rt)
         int_enc = torch.mul(enc, meta_enc)
+        print('int_enc', int_enc)
         out_int = self.decoder_int(int_enc)
+        print('out_int', out_int)
         return out_rt.flatten(), out_int
 
 