# model_custom.py
import math
import torch.nn as nn
import torch
from tape import TAPETokenizer
from tape.models.modeling_bert import ProteinBertModel


class PermuteLayer(nn.Module):
    """Permutes tensor dimensions so the operation can be used inside nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        x = torch.permute(x, self.dims)
        return x


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, applied after permuting to (seq_len, batch, dim) layout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor, shape ``[seq_len, batch_size, embedding_dim]``, with positional
            encodings added and dropout applied.
        """
        x = torch.permute(x, (1, 0, 2))
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class Model_Common_Transformer(nn.Module):

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        super(Model_Common_Transformer, self).__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        self.embedding_dim = embedding_dim
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate
        # Shared sequence encoder over the embedded amino-acid tokens, (seq_len, batch, dim) layout.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                       dim_feedforward=self.encoder_ff,
                                       dropout=self.drop_rate, activation=acti,
                                       norm_first=norm),
            num_layers=self.encoder_num_layer)
        # Projects the one-hot precursor charge onto the embedding space.
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)
        # Retention-time head: a transformer block followed by an MLP regressor on the flattened sequence.
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_rt_ff,
                                           dropout=self.drop_rate, activation=acti, norm_first=norm),
                num_layers=self.decoder_rt_num_layer),
            PermuteLayer((1, 0, 2)),  # (seq_len, batch, dim) -> (batch, seq_len, dim) before flattening
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )
        # Fragment-intensity head: same structure, with (seq_length - 1) * charge_frag_max * 2 outputs.
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_int_ff,
                                           dropout=self.drop_rate, activation=acti, norm_first=norm),
                num_layers=self.decoder_int_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )
        # Amino-acid one-hot -> embedding projection, plus the positional encoding applied on top of it.
        self.emb = nn.Linear(self.nb_aa, self.embedding_dim)
        self.pos_embedding = PositionalEncoding(max_len=self.seq_length, dropout=self.drop_rate,
                                                d_model=self.embedding_dim)

    def forward(self, seq, charge):
        # seq: (batch, seq_length) integer-encoded amino acids; charge: (batch,) precursor charge in [1, charge_max].
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))  # (seq_length, batch, embedding_dim)
        meta_enc = self.meta_enc(meta_ohe)           # (batch, embedding_dim)
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)           # broadcast the charge encoding over the sequence
        out_int = self.decoder_int(int_enc)
        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)
        return out_rt.flatten()

    def forward_int(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)
        enc = self.encoder(emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_int
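
# Usage note (sketch): Model_Common_Transformer expects integer-encoded peptides of shape
# (batch, seq_length) with values in [0, nb_aa) and precursor charges of shape (batch,) with
# values in [1, charge_max]; forward() returns a (batch,) retention-time vector and a
# (batch, (seq_length - 1) * charge_frag_max * 2) fragment-intensity matrix. A runnable
# version of this sketch is in the __main__ block at the end of the file.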


class Model_Common_Transformer_TAPE(nn.Module):

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        super(Model_Common_Transformer_TAPE, self).__init__()
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        # Fixed to 768 to match the hidden size of the pretrained ProteinBert encoder;
        # the embedding_dim argument is ignored in this variant.
        self.embedding_dim = 768
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate
        # All transformer blocks are batch_first here, matching the (batch, seq_len, 768) TAPE output.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                       dim_feedforward=self.encoder_ff,
                                       dropout=self.drop_rate, activation=acti,
                                       norm_first=norm, batch_first=True),
            num_layers=self.encoder_num_layer)
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_rt_ff,
                                           dropout=self.drop_rate, activation=acti, norm_first=norm,
                                           batch_first=True),
                num_layers=self.decoder_rt_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                           dim_feedforward=self.decoder_int_ff,
                                           dropout=self.drop_rate, activation=acti, norm_first=norm,
                                           batch_first=True),
                num_layers=self.decoder_int_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )
        # Pretrained TAPE ProteinBert encoder loaded from a local checkpoint directory.
        self.model_TAPE = ProteinBertModel.from_pretrained("./ProteinBert")

    def forward(self, seq, charge):
        # seq: (batch, seq_length) TAPE token ids; charge: (batch,) precursor charge in [1, charge_max].
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]                              # (batch, seq_length, 768) residue embeddings
        meta_enc = self.meta_enc(meta_ohe).unsqueeze(1)  # (batch, 1, 768), broadcast over the sequence
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        return out_rt.flatten()

    def forward_int(self, seq, charge):
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        meta_enc = self.meta_enc(meta_ohe).unsqueeze(1)  # (batch, 1, 768) so the product broadcasts over positions
        enc = self.encoder(seq_emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)
        return out_int
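

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, with assumed dummy inputs): integer-encoded
    # peptides of length seq_length with values in [0, nb_aa) and precursor charges in
    # [1, charge_max]. It only checks that the forward pass runs and returns the expected shapes.
    torch.manual_seed(0)
    batch_size = 4
    model = Model_Common_Transformer()
    seq = torch.randint(0, model.nb_aa, (batch_size, model.seq_length))
    charge = torch.randint(1, model.charge_max + 1, (batch_size,))
    rt, intensities = model(seq, charge)
    print(rt.shape)           # torch.Size([4])
    print(intensities.shape)  # torch.Size([4, 144]) with the default hyperparameters

    # Model_Common_Transformer_TAPE additionally requires the tape-proteins package and a local
    # "./ProteinBert" checkpoint directory, so it is only sketched here (hypothetical inputs):
    # tokenizer = TAPETokenizer(vocab='iupac')
    # token_ids = torch.tensor([tokenizer.encode('PEPTIDESEQ')])  # must yield seq_length tokens
    # tape_model = Model_Common_Transformer_TAPE()
    # rt, intensities = tape_model(token_ids, torch.tensor([2]))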