Commit b29d68ab authored by Léo Schneider

test cossim

parent 470500a3
@@ -164,30 +164,3 @@ class TransformerEncoder(nn.Module):
             x = l(x)
         return attention_maps
-
-
-class PositionalEncoding(nn.Module):
-
-    def __init__(self, d_model, max_len=5000):
-        """
-        Inputs
-            d_model - Hidden dimensionality of the input.
-            max_len - Maximum length of a sequence to expect.
-        """
-        super().__init__()
-
-        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-
-        # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
-        # Used for tensors that need to be on the same device as the module.
-        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
-        self.register_buffer('pe', pe, persistent=False)
-
-    def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
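For reference, the class removed above is the standard sinusoidal positional encoding from "Attention Is All You Need": even channels carry sin(pos / 10000^(2i/d_model)), odd channels the matching cosine, and the exp/log expression in div_term is a numerically stable way of computing 10000^(-2i/d_model). A minimal, self-contained sketch (not part of the commit; the tiny d_model and max_len are illustrative choices) that checks the buffer against the closed-form formula:

import math
import torch

d_model, max_len = 8, 4  # deliberately tiny, for illustration only
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)  # even channels
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels

# Spot-check one entry against the closed-form definition.
pos, i = 2, 1
expected = math.sin(pos / 10000 ** (2 * i / d_model))
assert torch.isclose(pe[pos, 2 * i], torch.tensor(expected))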
@@ -28,11 +28,13 @@ class PositionalEncoding(nn.Module):
         self.register_buffer('pe', pe)
 
     def forward(self, x):
+        print(x.size())
         x = torch.permute(x, (1, 0, 2))
         """
         Arguments:
             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
         """
+        print(self.pe[:x.size(0)].size())
         x = x + self.pe[:x.size(0)]
         return self.dropout(x)
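The two print calls added in this hunk log tensor shapes on either side of the buffer lookup: this variant of PositionalEncoding adds the encoding along the first dimension, so the batch-first input is permuted to [seq_len, batch_size, embedding_dim] first. A rough illustration of what the prints would show, under assumed sizes (batch 16, sequence length 30, d_model 256) and assuming the buffer is stored as [max_len, 1, d_model] as in the PyTorch tutorial; none of these numbers come from the repository:

import torch

batch_size, seq_len, d_model = 16, 30, 256   # assumed sizes
x = torch.randn(batch_size, seq_len, d_model)
pe = torch.zeros(5000, 1, d_model)           # assumed buffer layout: [max_len, 1, d_model]

print(x.size())                  # torch.Size([16, 30, 256]) -- before the permute
x = torch.permute(x, (1, 0, 2))  # -> [seq_len, batch_size, d_model]
print(pe[:x.size(0)].size())     # torch.Size([30, 1, 256])  -- broadcasts over the batch
x = x + pe[:x.size(0)]
print(x.size())                  # torch.Size([30, 16, 256])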
@@ -103,6 +105,7 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
+        print(seq_emb.size())
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
...
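The last print reports the shape of the one-hot-encoded sequence before it goes through the embedding layer. A small sketch of that encoding step; the vocabulary size (nb_aa) and maximum charge (charge_max) below are hypothetical placeholders, not values taken from the model:

import torch

nb_aa, charge_max = 23, 6                          # hypothetical sizes
seq = torch.randint(0, nb_aa, (16, 30))            # [batch, seq_len] integer-coded residues
charge = torch.randint(1, charge_max + 1, (16,))   # charges are 1-based, hence the `- 1` below

seq_emb = torch.nn.functional.one_hot(seq, nb_aa).float()
meta_ohe = torch.nn.functional.one_hot(charge - 1, charge_max).float()
print(seq_emb.size())   # torch.Size([16, 30, 23]) -- what the added print would report
print(meta_ohe.size())  # torch.Size([16, 6])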