Commit faee664f authored by Léo Schneider

test cossim

parent 7d9d2912
@@ -95,7 +95,7 @@ def alphabetical_to_numerical(seq, vocab):
     else :
         for i in range(len(seq) - 2 * seq.count('-')):
             if seq[i + dec] != '-':
-                num.append(IUPAC_VOCAB[seq[i + dec]])
+                num.append(ALPHABET_UNMOD[seq[i + dec]])
             else:
                 if seq[i + dec + 1:i + dec + 4] == 'CaC':
                     num.append(21)
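The change above swaps the residue-to-index mapping from IUPAC_VOCAB to ALPHABET_UNMOD while keeping the special handling of '-'-delimited modifications ('CaC' is mapped to index 21). A minimal sketch of that encoding idea, using a hypothetical alphabet and a simplified scan instead of the function's dec offset; names and indices here are illustrative, not the repository's definitions:

# Minimal sketch of the encoding idea, not the repository's exact code.
# ALPHABET_UNMOD here is a hypothetical stand-in mapping residues to indices.
ALPHABET_UNMOD = {aa: i + 1 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}

def encode(seq):
    num = []
    i = 0
    while i < len(seq):
        if seq[i] == '-' and seq[i + 1:i + 4] == 'CaC':
            num.append(21)      # 'CaC' (presumably a modified cysteine) gets its own index
            i += 5              # skip over the '-CaC-' marker
        else:
            num.append(ALPHABET_UNMOD[seq[i]])
            i += 1
    return num

print(encode('PEPT-CaC-IDE'))   # [13, 4, 13, 17, 21, 8, 3, 4]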
@@ -25,17 +25,14 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(max_len, 1, d_model)
         pe[:, 0, 0::2] = torch.sin(position * div_term)
         pe[:, 0, 1::2] = torch.cos(position * div_term)
-        print(d_model, max_len)
         self.register_buffer('pe', pe)

     def forward(self, x):
         x = torch.permute(x, (1, 0, 2))
         """
         Arguments:
             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
         """
         x = x + self.pe[:x.size(0)]
         return self.dropout(x)
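The hunk above only drops the print(d_model, max_len) debug line; the class itself follows the standard sinusoidal positional-encoding recipe. For context, a self-contained sketch of that recipe; the position/div_term setup is not shown in the hunk, so it is written here in the usual PyTorch-tutorial form and may differ from the repository's exact code:

import math
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (sketch, assuming the usual tutorial layout)."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)   # even dimensions: sine
        pe[:, 0, 1::2] = torch.cos(position * div_term)   # odd dimensions: cosine
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Incoming x is (batch, seq_len, d_model); the table is indexed by position,
        # so x is permuted to (seq_len, batch, d_model) before the addition.
        x = torch.permute(x, (1, 0, 2))
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)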
@@ -106,8 +103,6 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
-        print(seq_emb.size())
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
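The removed permute turned the one-hot tensor into (batch, nb_aa, seq_len) before the embedding; without it, F.one_hot's output keeps the (batch, seq_len, nb_aa) layout, which is what a linear embedding acting on the last dimension expects. A small shape check illustrating this, assuming self.emb is an nn.Linear(nb_aa, d_model) (an assumption; the hunk does not show its definition):

import torch
from torch import nn

nb_aa, d_model, batch, seq_len = 23, 64, 2, 30       # illustrative sizes, not from the repo
seq = torch.randint(0, nb_aa, (batch, seq_len))      # integer-encoded peptides

seq_emb = torch.nn.functional.one_hot(seq, nb_aa).float()
print(seq_emb.shape)                                 # torch.Size([2, 30, 23])

emb_layer = nn.Linear(nb_aa, d_model)                # assumed form of self.emb
emb = emb_layer(seq_emb)                             # acts on the last dim, no permute needed
print(emb.shape)                                     # torch.Size([2, 30, 64])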
@@ -121,7 +116,6 @@ class Model_Common_Transformer(nn.Module):
     def forward_rt(self, seq):
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
         emb = self.pos_embedding(self.emb(seq_emb))
         enc = self.encoder(emb)
         out_rt = self.decoder_RT(enc)
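forward_rt follows the same sequence path without the charge metadata: one-hot, embed, positional encoding, shared encoder, then the retention-time head. With PyTorch's default batch_first=False, nn.TransformerEncoder expects the (seq_len, batch, d_model) layout that PositionalEncoding.forward produces after its permute. A minimal sketch of that encoder call, with illustrative hyperparameters rather than the repository's:

import torch
from torch import nn

d_model, nhead, num_layers = 64, 8, 2                # illustrative hyperparameters
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
    num_layers=num_layers,
)

emb = torch.randn(30, 2, d_model)                    # (seq_len, batch, d_model), the default layout
enc = encoder(emb)
print(enc.shape)                                     # torch.Size([30, 2, 64])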
@@ -131,7 +125,6 @@ class Model_Common_Transformer(nn.Module):
     def forward_int(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
         enc = self.encoder(emb)
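forward_int encodes the precursor charge alongside the sequence: charge - 1 shifts charges 1..charge_max onto valid one-hot indices 0..charge_max - 1 before they are passed to the metadata encoder. A minimal sketch of that step, with charge_max = 6 chosen purely for illustration (not taken from the repository):

import torch

charge_max = 6                                   # illustrative value, not from the repo
charge = torch.tensor([1, 2, 3])                 # precursor charges, one per peptide

# charge - 1 maps charges 1..charge_max onto valid one-hot indices 0..charge_max - 1
meta_ohe = torch.nn.functional.one_hot(charge - 1, charge_max).float()
print(meta_ohe)
# tensor([[1., 0., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.]])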