diff --git a/model_custom.py b/model_custom.py
index 2c024d15fb0abf47257aac3030833ee2c1166973..fc5df00e783bd59b8c7d1baf4b5a04bd1f4735d2 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -16,10 +16,9 @@ class PermuteLayer(nn.Module):
 
 class PositionalEncoding(nn.Module):
 
-    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26):
+    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 30):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
-
         position = torch.arange(max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe = torch.zeros(max_len, 1, d_model)
@@ -42,9 +41,9 @@ class Model_Common_Transformer(nn.Module):
     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=23,
                  regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                  n_head=1, seq_length=25,
-                 charge_max=5, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
+                 charge_max=6, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                  decoder_int_num_layer=1, acti='relu', norm=False):
-        self.charge_max = charge_max
+        self.charge_max = charge_max #TODO filter charge in train to be in 1-4 0-5 atm
         self.seq_length = seq_length
         self.nb_aa = nb_aa
         self.charge_frag_max = charge_frag_max
@@ -101,24 +100,24 @@ class Model_Common_Transformer(nn.Module):
                                                 d_model=self.embedding_dim)
 
     def forward(self, seq, charge):
-        print('seq', seq)
-        print('charge', charge)
-        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
-        print('meta_ohe', meta_ohe)
+        # print('seq', seq)
+        # print('charge', charge)
+        meta_ohe = torch.nn.functional.one_hot(charge, self.charge_max).float()
+        # print('meta_ohe', meta_ohe)
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
-        print('seq_emb', seq_emb)
+        # print('seq_emb', seq_emb)
         emb = self.pos_embedding(self.emb(seq_emb))
-        print('emb', emb)
+        # print('emb', emb)
         meta_enc = self.meta_enc(meta_ohe)
-        print('meta_enc', meta_enc)
+        # print('meta_enc', meta_enc)
         enc = self.encoder(emb)
-        print('enc', enc)
+        # print('enc', enc)
         out_rt = self.decoder_RT(enc)
-        print('out_rt', out_rt)
+        # print('out_rt', out_rt)
         int_enc = torch.mul(enc, meta_enc)
-        print('int_enc', int_enc)
+        # print('int_enc', int_enc)
         out_int = self.decoder_int(int_enc)
-        print('out_int', out_int)
+        # print('out_int', out_int)
 
         return out_rt.flatten(), out_int