import math

import torch
import torch.nn as nn
from tape import TAPETokenizer
from tape.models.modeling_bert import ProteinBertModel


class PermuteLayer(nn.Module):
    """Permutes the dimensions of its input so axis reordering can be used inside nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return torch.permute(x, self.dims)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor, shape ``[seq_len, batch_size, embedding_dim]`` (sequence-first, as
            expected by the non-batch-first transformer encoder downstream).
        """
        x = torch.permute(x, (1, 0, 2))
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class Model_Common_Transformer(nn.Module):
    """Transformer with a shared peptide encoder and two heads: a retention-time (RT)
    regressor and a fragment-intensity regressor conditioned on precursor charge."""

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512,
                 decoder_rt_ff=512, decoder_int_ff=512, n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1,
                 decoder_rt_num_layer=1, decoder_int_num_layer=1, acti='relu', norm=False):
        super().__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        self.embedding_dim = embedding_dim
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate

        # Shared sequence encoder (sequence-first layout: [seq_len, batch, embedding_dim]).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                       dim_feedforward=self.encoder_ff, dropout=self.drop_rate,
                                       activation=acti, norm_first=norm),
            num_layers=self.encoder_num_layer)

        # Projects the one-hot precursor charge into the embedding space.
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        # Retention-time head: transformer stack followed by an MLP regressor.
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_rt_ff, dropout=self.drop_rate,
                                           activation=acti, norm_first=norm),
                num_layers=self.decoder_rt_num_layer),
            PermuteLayer((1, 0, 2)),  # back to [batch, seq_len, embedding_dim] before flattening
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        # Intensity head: predicts (seq_length - 1) * charge_frag_max * 2 fragment intensities.
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_int_ff, dropout=self.drop_rate,
                                           activation=acti, norm_first=norm),
                num_layers=self.decoder_int_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        # Embeds one-hot amino acids, then adds sinusoidal positional encodings.
        self.emb = nn.Linear(self.nb_aa, self.embedding_dim)
        self.pos_embedding = PositionalEncoding(max_len=self.seq_length, dropout=self.drop_rate,
                                                d_model=self.embedding_dim)

    def forward(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)
        # Condition the intensity branch on charge by element-wise scaling of the encoding.
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)
        return out_rt.flatten()

    def forward_int(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)
        enc = self.encoder(emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_int
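# Minimal usage sketch for Model_Common_Transformer (not part of the original training
# code): the batch size and the randomly generated token ids / charges below are
# placeholder assumptions, used only to illustrate the expected input and output shapes.
def _demo_common_transformer():
    model = Model_Common_Transformer()
    batch_size = 8
    # Integer-encoded peptides, values in [0, nb_aa), padded/truncated to seq_length.
    seq = torch.randint(0, model.nb_aa, (batch_size, model.seq_length))
    # Precursor charges in [1, charge_max]; the model one-hot encodes charge - 1.
    charge = torch.randint(1, model.charge_max + 1, (batch_size,))
    out_rt, out_int = model(seq, charge)
    # out_rt: [batch_size]; out_int: [batch_size, (seq_length - 1) * charge_frag_max * 2] = [8, 144]
    return out_rt.shape, out_int.shape

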
class Model_Common_Transformer_TAPE(nn.Module):
    """Variant of Model_Common_Transformer that replaces the one-hot amino-acid embedding
    with TAPE ProteinBert embeddings (hidden size 768)."""

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512,
                 decoder_rt_ff=512, decoder_int_ff=512, n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1,
                 decoder_rt_num_layer=1, decoder_int_num_layer=1, acti='relu', norm=False):
        super().__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        # The embedding size is fixed to 768 to match ProteinBert's hidden size,
        # regardless of the embedding_dim argument.
        self.embedding_dim = 768
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate

        # Shared sequence encoder (batch-first layout: [batch, seq_len, embedding_dim]).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                       dim_feedforward=self.encoder_ff, dropout=self.drop_rate,
                                       activation=acti, norm_first=norm, batch_first=True),
            num_layers=self.encoder_num_layer)

        # Projects the one-hot precursor charge into the embedding space.
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        # Retention-time head.
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_rt_ff, dropout=self.drop_rate,
                                           activation=acti, norm_first=norm, batch_first=True),
                num_layers=self.decoder_rt_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        # Intensity head.
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_int_ff, dropout=self.drop_rate,
                                           activation=acti, norm_first=norm, batch_first=True),
                num_layers=self.decoder_int_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        # Pretrained TAPE ProteinBert loaded from a local checkpoint directory.
        self.model_TAPE = ProteinBertModel.from_pretrained("./ProteinBert")

    def forward(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]  # last hidden states: [batch, seq_len, 768]
        meta_enc = self.meta_enc(meta_ohe)
        # Broadcast the charge encoding across the sequence dimension.
        meta_enc = meta_enc.unsqueeze(-1).expand(-1, -1, self.seq_length)
        meta_enc = torch.permute(meta_enc, (0, 2, 1))
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        return out_rt.flatten()

    def forward_int(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        meta_enc = self.meta_enc(meta_ohe)
        # Broadcast the charge encoding across the sequence dimension (as in forward).
        meta_enc = meta_enc.unsqueeze(-1).expand(-1, -1, self.seq_length)
        meta_enc = torch.permute(meta_enc, (0, 2, 1))
        enc = self.encoder(seq_emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_int
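

# Minimal usage sketch for Model_Common_Transformer_TAPE (not part of the original code):
# it assumes a ProteinBert checkpoint is available at ./ProteinBert and uses random token
# ids as placeholders; real inputs would come from TAPETokenizer, padded/truncated to
# seq_length so the flattened decoder inputs match.
def _demo_tape_transformer():
    model = Model_Common_Transformer_TAPE()
    batch_size = 8
    # Placeholder TAPE token ids of length seq_length (ids must lie inside the TAPE vocabulary).
    seq = torch.randint(0, model.nb_aa, (batch_size, model.seq_length))
    charge = torch.randint(1, model.charge_max + 1, (batch_size,))
    out_rt, out_int = model(seq, charge)
    # out_rt: [batch_size]; out_int: [batch_size, (seq_length - 1) * charge_frag_max * 2]
    return out_rt.shape, out_int.shape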