From b29d68ab4ed964f30089597f4517abc20b904cd1 Mon Sep 17 00:00:00 2001
From: lschneider <leo.schneider@univ-lyon1.fr>
Date: Thu, 19 Sep 2024 14:41:17 +0200
Subject: [PATCH] test cossim

---
 layers.py       | 27 ---------------------------
 model_custom.py |  3 +++
 2 files changed, 3 insertions(+), 27 deletions(-)

diff --git a/layers.py b/layers.py
index c1ca54a..6c3f94f 100644
--- a/layers.py
+++ b/layers.py
@@ -164,30 +164,3 @@ class TransformerEncoder(nn.Module):
             x = l(x)
 
         return attention_maps
-
-class PositionalEncoding(nn.Module):
-
-    def __init__(self, d_model, max_len=5000):
-        """
-        Inputs
-            d_model - Hidden dimensionality of the input.
-            max_len - Maximum length of a sequence to expect.
-        """
-        super().__init__()
-
-        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-
-        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
-        # Used for tensors that need to be on the same device as the module.
-        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
-        self.register_buffer('pe', pe, persistent=False)
-
-    def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
diff --git a/model_custom.py b/model_custom.py
index 96cfc64..ce6bb1c 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -28,11 +28,13 @@ class PositionalEncoding(nn.Module):
         self.register_buffer('pe', pe)
 
     def forward(self, x):
+        print(x.size())
         x = torch.permute(x, (1, 0, 2))
         """
         Arguments:
             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
         """
+        print(self.pe[:x.size(0)].size())
         x = x + self.pe[:x.size(0)]
         return self.dropout(x)
 
@@ -103,6 +105,7 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
+        print(seq_emb.size())
         emb = self.pos_embedding(self.emb(seq_emb))
 
         meta_enc = self.meta_enc(meta_ohe)
--
GitLab