diff --git a/model_custom.py b/model_custom.py
index fc5df00e783bd59b8c7d1baf4b5a04bd1f4735d2..42dbcb4a48f2691d90c435180f14f4f1cedd463c 100644
--- a/model_custom.py
+++ b/model_custom.py
@@ -41,7 +41,7 @@ class Model_Common_Transformer(nn.Module):
     def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=23, regressor_layer_size_rt=512,
                  regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512, n_head=1, seq_length=25,
-                 charge_max=6, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
+                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                  decoder_int_num_layer=1, acti='relu', norm=False):
         self.charge_max = charge_max #TODO filter charge in train to be in 1-4 0-5 atm
         self.seq_length = seq_length
@@ -102,7 +102,7 @@ class Model_Common_Transformer(nn.Module):
     def forward(self, seq, charge):
         # print('seq', seq)
         # print('charge', charge)
-        meta_ohe = torch.nn.functional.one_hot(charge, self.charge_max).float()
+        meta_ohe = torch.nn.functional.one_hot(charge-1, self.charge_max).float()
         # print('meta_ohe', meta_ohe)
         seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
         # print('seq_emb', seq_emb)
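
Note (illustration, not part of the patch): a minimal sketch of why the -1 offset accompanies the new charge_max=4, assuming precursor charges are passed in as 1-based values in 1..charge_max. With num_classes=4, one_hot only accepts indices 0..3, so shifting by one maps charge 1 to index 0 and keeps charge 4 in range.

```python
import torch

charge_max = 4
# Assumed example batch: 1-based precursor charges in 1..charge_max.
charge = torch.tensor([1, 2, 4])

# Without the -1 shift, a charge of 4 would request class index 4,
# which is out of range for a 4-class one-hot and raises an error.
meta_ohe = torch.nn.functional.one_hot(charge - 1, charge_max).float()
print(meta_ohe)
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 0., 1.]])
```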