From faee664f733cfb4ce1bf2c576968c2cc6b309ada Mon Sep 17 00:00:00 2001
From: lschneider <leo.schneider@univ-lyon1.fr>
Date: Thu, 19 Sep 2024 15:34:20 +0200
Subject: [PATCH] test cossim: fix residue vocab lookup, remove debug leftovers

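In alphabetical_to_numerical, look up unmodified residues in
ALPHABET_UNMOD instead of IUPAC_VOCAB. In model_custom.py, drop the
leftover debug print() calls and the permute of the one-hot sequence
embedding in forward, forward_rt and forward_int.
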
---
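Reviewer note (this text after the '---' separator is not part of the
commit message): the subject refers to a cosine-similarity ("cossim")
test that does not itself appear in this diff. For context, a minimal
sketch of such a check, with hypothetical pred/target tensors and
placeholder sizes (not taken from this repo), could look like:

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes: a batch of 8 predicted vs. reference
    # intensity vectors of length 174 (placeholder dimensions).
    pred = torch.rand(8, 174)
    target = torch.rand(8, 174)

    # One cosine-similarity score per sample, then a scalar "cosine
    # distance" that could serve as a loss or a test metric.
    cos = F.cosine_similarity(pred, target, dim=1)
    loss = 1.0 - cos.mean()
    print(cos.shape, loss.item())
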
 common_dataset.py | 2 +-
 model_custom.py   | 7 -------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/common_dataset.py b/common_dataset.py
index 6e46845..5ab8a5f 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -95,7 +95,7 @@ def alphabetical_to_numerical(seq, vocab):
     else :
         for i in range(len(seq) - 2 * seq.count('-')):
             if seq[i + dec] != '-':
-                num.append(IUPAC_VOCAB[seq[i + dec]])
+                num.append(ALPHABET_UNMOD[seq[i + dec]])
             else:
                 if seq[i + dec + 1:i + dec + 4] == 'CaC':
                     num.append(21)
diff --git a/model_custom.py b/model_custom.py
index d2888b4..96cfc64 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -25,17 +25,14 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(max_len, 1, d_model)
         pe[:, 0, 0::2] = torch.sin(position * div_term)
         pe[:, 0, 1::2] = torch.cos(position * div_term)
-        print(d_model, max_len)
         self.register_buffer('pe', pe)
 
     def forward(self, x):
-
         x = torch.permute(x, (1, 0, 2))
         """
         Arguments:
             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
         """
-
         x = x + self.pe[:x.size(0)]
         return self.dropout(x)
 
@@ -106,8 +103,6 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
-        print(seq_emb.size())
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
 
@@ -121,7 +116,6 @@ class Model_Common_Transformer(nn.Module):
 
     def forward_rt(self, seq):
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
         emb = self.pos_embedding(self.emb(seq_emb))
         enc = self.encoder(emb)
         out_rt = self.decoder_RT(enc)
@@ -131,7 +125,6 @@ class Model_Common_Transformer(nn.Module):
     def forward_int(self, seq, charge):
         meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        seq_emb = torch.permute(seq_emb, (0, 2, 1))
         emb = self.pos_embedding(self.emb(seq_emb))
         meta_enc = self.meta_enc(meta_ohe)
         enc = self.encoder(emb)
-- 
GitLab