diff --git a/model_custom.py b/model_custom.py index a7947f5080ba6b2f2a26f12fab38ccb1bda9427d..dc73f4c034d139cbdf698ac4f3478922266d408f 100644 --- a/model_custom.py +++ b/model_custom.py @@ -25,11 +25,12 @@ class PositionalEncoding(nn.Module): pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) + print(d_model, max_len) self.register_buffer('pe', pe) def forward(self, x): - + x = torch.permute(x, (1, 0, 2)) """ Arguments: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` @@ -119,6 +120,7 @@ class Model_Common_Transformer(nn.Module): def forward_rt(self, seq): seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() + seq_emb = torch.permute(seq_emb, (0, 2, 1)) emb = self.pos_embedding(self.emb(seq_emb)) enc = self.encoder(emb) out_rt = self.decoder_RT(enc)