diff --git a/model_custom.py b/model_custom.py index 2c024d15fb0abf47257aac3030833ee2c1166973..fc5df00e783bd59b8c7d1baf4b5a04bd1f4735d2 100644 --- a/model_custom.py +++ b/model_custom.py @@ -16,10 +16,9 @@ class PermuteLayer(nn.Module): class PositionalEncoding(nn.Module): - def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 30): super().__init__() self.dropout = nn.Dropout(p=dropout) - position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, 1, d_model) @@ -42,9 +41,9 @@ class Model_Common_Transformer(nn.Module): def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=23, regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512, n_head=1, seq_length=25, - charge_max=5, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, + charge_max=6, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1, decoder_int_num_layer=1, acti='relu', norm=False): - self.charge_max = charge_max + self.charge_max = charge_max #TODO filter charge in train to be in 1-4 0-5 atm self.seq_length = seq_length self.nb_aa = nb_aa self.charge_frag_max = charge_frag_max @@ -101,24 +100,24 @@ class Model_Common_Transformer(nn.Module): d_model=self.embedding_dim) def forward(self, seq, charge): - print('seq', seq) - print('charge', charge) - meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float() - print('meta_ohe', meta_ohe) + # print('seq', seq) + # print('charge', charge) + meta_ohe = torch.nn.functional.one_hot(charge, self.charge_max).float() + # print('meta_ohe', meta_ohe) seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() - print('seq_emb', seq_emb) + # print('seq_emb', seq_emb) emb = self.pos_embedding(self.emb(seq_emb)) - print('emb', emb) + # print('emb', emb) meta_enc = self.meta_enc(meta_ohe) - print('meta_enc', meta_enc) + # print('meta_enc', meta_enc) enc = self.encoder(emb) - print('enc', enc) + # print('enc', enc) out_rt = self.decoder_RT(enc) - print('out_rt', out_rt) + # print('out_rt', out_rt) int_enc = torch.mul(enc, meta_enc) - print('int_enc', int_enc) + # print('int_enc', int_enc) out_int = self.decoder_int(int_enc) - print('out_int', out_int) + # print('out_int', out_int) return out_rt.flatten(), out_int