diff --git a/model_custom.py b/model_custom.py index 04f087072a349ce211a869ae88526a97525142d3..f829002b87c54452330ac46d4ce6767c4a2e0e64 100644 --- a/model_custom.py +++ b/model_custom.py @@ -106,7 +106,7 @@ class Model_Common_Transformer(nn.Module): seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() print(seq_emb.shape) print(self.nb_aa, self.embedding_dim) - print(self.emb(seq_emb)) + print(seq_emb) emb = self.pos_embedding(self.emb(seq_emb)) meta_enc = self.meta_enc(meta_ohe)