From e4c26d5ecdb6e8b1cb4834520adb6d7554dc3d9a Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 22 Oct 2024 14:19:06 +0200
Subject: [PATCH] Shift charge to 0-based for one-hot encoding; lower charge_max default to 4

---
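Note: torch.nn.functional.one_hot expects class indices in
[0, num_classes - 1], so precursor charges 1..charge_max have to be
shifted down by one before encoding; without the shift, a charge equal
to charge_max raises RuntimeError ("Class values must be smaller than
num_classes"). A minimal standalone sketch of the encoding in the
second hunk (the tensor values are illustrative, not taken from the
training data):

    import torch

    charge_max = 4                       # matches the new default below
    charge = torch.tensor([1, 2, 3, 4])  # hypothetical batch of precursor charges

    # one_hot(charge, charge_max) would fail for charge == 4;
    # charge - 1 maps charges 1..4 onto indices 0..3.
    meta_ohe = torch.nn.functional.one_hot(charge - 1, charge_max).float()
    # tensor([[1., 0., 0., 0.],
    #         [0., 1., 0., 0.],
    #         [0., 0., 1., 0.],
    #         [0., 0., 0., 1.]])
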
 model_custom.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model_custom.py b/model_custom.py
index fc5df00..42dbcb4 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -41,7 +41,7 @@ class Model_Common_Transformer(nn.Module):
     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=23,
                  regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                  n_head=1, seq_length=25,
-                 charge_max=6, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
+                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                  decoder_int_num_layer=1, acti='relu', norm=False):
         self.charge_max = charge_max  # TODO: filter charge in the training data to the range 1-4 (it is 0-5 at the moment)
         self.seq_length = seq_length
@@ -102,7 +102,7 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         # print('seq', seq)
         # print('charge', charge)
-        meta_ohe = torch.nn.functional.one_hot(charge, self.charge_max).float()
+        meta_ohe = torch.nn.functional.one_hot(charge-1, self.charge_max).float()
         # print('meta_ohe', meta_ohe)
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
         # print('seq_emb', seq_emb)
-- 
GitLab