diff --git a/model_custom.py b/model_custom.py index ce6bb1c062ac55453bd5af63759ffc42a8160726..0a675021ff2894025bd95d4b2a5f6f59b52f74d3 100644 --- a/model_custom.py +++ b/model_custom.py @@ -25,16 +25,17 @@ class PositionalEncoding(nn.Module): pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) + print(d_model, max_len) self.register_buffer('pe', pe) def forward(self, x): - print(x.size()) + x = torch.permute(x, (1, 0, 2)) """ Arguments: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` """ - print(self.pe[:x.size(0)].size()) + x = x + self.pe[:x.size(0)] return self.dropout(x)