diff --git a/model_custom.py b/model_custom.py index 0a675021ff2894025bd95d4b2a5f6f59b52f74d3..b473a04fb7742cda80ae2d3e78876304810759aa 100644 --- a/model_custom.py +++ b/model_custom.py @@ -30,7 +30,7 @@ class PositionalEncoding(nn.Module): def forward(self, x): - x = torch.permute(x, (1, 0, 2)) + """ Arguments: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``