# model_custom.py
import math

import torch
import torch.nn as nn
from tape import TAPETokenizer
from tape.models.modeling_bert import ProteinBertModel

class PermuteLayer(nn.Module):
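    """Applies a fixed torch.permute so a dimension reordering can be used inside nn.Sequential."""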
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        x = torch.permute(x, self.dims)
        return x


class PositionalEncoding(nn.Module):
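    """Fixed sinusoidal positional encoding (as in "Attention Is All You Need"),
    added to the input embeddings and followed by dropout."""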

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns the encoded tensor permuted to ``[seq_len, batch_size, embedding_dim]``,
        the layout expected by the sequence-first ``nn.TransformerEncoder`` used below.
        """
        x = torch.permute(x, (1, 0, 2))
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class Model_Common_Transformer(nn.Module):
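    """Transformer with a shared encoder over embedded peptide sequences and two
    regression heads: a retention-time ("RT") head and a fragment-intensity ("int")
    head that is conditioned on the precursor charge."""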

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        super().__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        self.embedding_dim = embedding_dim
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate

        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                                        dim_feedforward=self.encoder_ff,
                                                                        dropout=self.drop_rate, activation=acti,
                                                                        norm_first=norm),
                                             num_layers=self.encoder_num_layer)

        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_rt_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm),
                                  num_layers=self.decoder_rt_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_int_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm),
                                  num_layers=self.decoder_int_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_dim)

        self.pos_embedding = PositionalEncoding(max_len=self.seq_length, dropout=self.drop_rate,
                                                d_model=self.embedding_dim)

    def forward(self, seq, charge):
        # One-hot encode the precursor charge (values 1..charge_max) and the amino-acid indices.
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        # Linear projection of the one-hot sequence plus sinusoidal positions,
        # permuted to (seq_len, batch, embedding_dim) for the sequence-first encoder.
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)

        enc = self.encoder(emb)

        # The RT head reads the encoder output directly; the intensity head is
        # conditioned on charge by element-wise multiplication with meta_enc.
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
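        """Predict retention time only; the charge-conditioned intensity head is skipped."""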
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)

        return out_rt.flatten()

    def forward_int(self, seq, charge):
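        """Predict fragment intensities, conditioned on the precursor charge."""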
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)
        enc = self.encoder(emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_int


class Model_Common_Transformer_TAPE(nn.Module):
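    """Variant of Model_Common_Transformer that replaces the one-hot embedding and
    sinusoidal positional encoding with a pretrained TAPE ProteinBert encoder, so
    all downstream layers operate on 768-dimensional, batch-first representations."""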

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        super().__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        self.embedding_dim = 768  # hidden size of TAPE ProteinBert; overrides the embedding_dim argument
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate

        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                                        dim_feedforward=self.encoder_ff,
                                                                        dropout=self.drop_rate, activation=acti,
                                                                        norm_first=norm, batch_first=True),
                                             num_layers=self.encoder_num_layer)

        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_rt_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm,
                                                             batch_first=True),
                                  num_layers=self.decoder_rt_num_layer),

            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_int_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm,
                                                             batch_first=True),
                                  num_layers=self.decoder_int_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        # Pretrained TAPE ProteinBert encoder, loaded from a local checkpoint directory.
        self.model_TAPE = ProteinBertModel.from_pretrained("./ProteinBert")

    def forward(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        # ProteinBert returns a tuple whose first element holds the last-layer
        # hidden states, shape (batch, seq_len, 768).
        output = self.model_TAPE(seq)
        seq_emb = output[0]

        # Broadcast the charge encoding across the sequence dimension:
        # (batch, 768) -> (batch, seq_length, 768).
        meta_enc = self.meta_enc(meta_ohe)
        meta_enc = meta_enc.unsqueeze(-1).expand(-1, -1, self.seq_length)
        meta_enc = torch.permute(meta_enc, (0, 2, 1))

        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
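        """Predict retention time only from the ProteinBert sequence embeddings."""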
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)

        return out_rt.flatten()

    def forward_int(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        meta_enc = self.meta_enc(meta_ohe)
        # Broadcast the charge encoding across the sequence dimension, as in forward();
        # without this, (batch, 768) cannot be multiplied with (batch, seq_length, 768).
        meta_enc = meta_enc.unsqueeze(-1).expand(-1, -1, self.seq_length)
        meta_enc = torch.permute(meta_enc, (0, 2, 1))
        enc = self.encoder(seq_emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_int
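

# A minimal, self-contained shape check (a sketch, not part of the original training
# pipeline). It assumes integer-encoded peptides of length seq_length drawn from nb_aa
# tokens and precursor charges in 1..charge_max, matching the default constructor
# arguments. Model_Common_Transformer_TAPE is not exercised here because it requires
# the local "./ProteinBert" checkpoint.
if __name__ == "__main__":
    model = Model_Common_Transformer()
    batch_size = 8
    seq = torch.randint(0, model.nb_aa, (batch_size, model.seq_length))
    charge = torch.randint(1, model.charge_max + 1, (batch_size,))
    out_rt, out_int = model(seq, charge)
    # RT head: one value per peptide. Intensity head:
    # (seq_length - 1) * charge_frag_max * 2 = 144 values per peptide.
    print(out_rt.shape, out_int.shape)  # torch.Size([8]) torch.Size([8, 144])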