From 2f95dc5c17688f1fbd9a579537552a6de63a96e0 Mon Sep 17 00:00:00 2001
From: lschneider <leo.schneider@univ-lyon1.fr>
Date: Thu, 19 Sep 2024 15:19:12 +0200
Subject: [PATCH] test cossim

---
 model_custom.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/model_custom.py b/model_custom.py
index a7947f5..dc73f4c 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -25,11 +25,12 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(max_len, 1, d_model)
         pe[:, 0, 0::2] = torch.sin(position * div_term)
         pe[:, 0, 1::2] = torch.cos(position * div_term)
+        print(d_model, max_len)
         self.register_buffer('pe', pe)
 
     def forward(self, x):
 
-
+        x = torch.permute(x, (1, 0, 2))
         """
         Arguments:
             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
@@ -119,6 +120,7 @@ class Model_Common_Transformer(nn.Module):
 
     def forward_rt(self, seq):
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
+        seq_emb = torch.permute(seq_emb, (0, 2, 1))
         emb = self.pos_embedding(self.emb(seq_emb))
         enc = self.encoder(emb)
         out_rt = self.decoder_RT(enc)
-- 
GitLab