Skip to content
Snippets Groups Projects
model.py 14.59 KiB
import numpy as np
import torch.nn as nn
import torch

from layers import SelectItem, SelfAttention_multi, SelfAttention, TransformerEncoder


class RT_pred_model(nn.Module):
    """Small bidirectional-GRU regressor over one-hot encoded token sequences.

    Token indices are one-hot encoded over 24 classes and projected to 8
    features. A 2-layer bidirectional GRU (hidden size 16) encodes the
    sequence. Its selected output is reshaped to 64 features per sample and
    fed to a small MLP ending in a single output unit.

    Args:
        drop_rate: dropout probability used between GRU layers, after the
            encoder and inside the MLP.
    """

    def __init__(self, drop_rate):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        recurrent = nn.GRU(
            input_size=8,
            hidden_size=16,
            num_layers=2,
            dropout=drop_rate,
            bidirectional=True,
            batch_first=True,
        )
        # NOTE(review): SelectItem(1) presumably picks h_n from the GRU's
        # (output, h_n) tuple. That matches the 64 = (2 layers * 2 directions) * 16
        # input width of the first decoder Linear.
        self.encoder = nn.Sequential(recurrent, SelectItem(1), nn.Dropout(p=drop_rate))

        self.decoder = nn.Sequential(
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.Linear(8, 1),
        )

        self.emb = nn.Linear(24, 8)

        for module in (self.encoder, self.decoder, self.emb):
            module.float()

    def forward(self, seq):
        """Return the flattened regression output for a batch of index sequences.

        ``seq`` holds integer token indices, each < 24 (required by one_hot).
        """
        one_hot = torch.nn.functional.one_hot(seq, 24).float()
        state = self.encoder(self.emb(one_hot))
        # Move the batch axis to the front, then merge the remaining axes into
        # one feature vector per sample.
        per_sample = torch.flatten(state.swapaxes(0, 1), start_dim=1)
        return torch.flatten(self.decoder(per_sample))


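# TODO: remove this class once RT_pred_model_self_attention_multi_sum is validated; with n_head=1 that class builds an equivalent architecture.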
class RT_pred_model_self_attention(nn.Module):
    """Two stacked bidirectional GRUs followed by a single-head self-attention regression head.

    Token indices are one-hot encoded over ``nb_aa`` classes and linearly
    embedded. They pass through two bidirectional GRUs, then through
    ``SelfAttention_multi`` (one head) and an MLP ending in a single output
    unit. ``forward`` returns the flattened head output.

    Args:
        drop_rate: dropout probability applied after each GRU.
        embedding_output_dim: width of the linear token embedding.
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: dropout probability inside the regression MLP.
        regressor_layer_size: hidden width of the regression MLP.
        recurrent_layers_sizes: hidden sizes of the first and second GRU.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        super(RT_pred_model_self_attention, self).__init__()
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        # The GRUs are not given `dropout=`. With num_layers=1, PyTorch ignores
        # that argument (it only applies between stacked layers) and emits a
        # UserWarning. The explicit nn.Dropout layers below do the regularisation.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=self.drop_rate),
        )

        # NOTE(review): the output size of SelfAttention_multi is assumed to be
        # recurrent_layers_sizes[1] * 2, since it feeds this Linear directly. Whether
        # it pools over the sequence axis is not visible here - TODO confirm.
        self.decoder = nn.Sequential(
            SelfAttention_multi(self.recurrent_layers_sizes[1] * 2, 1),
            nn.Linear(self.recurrent_layers_sizes[1] * 2, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        """Encode integer token indices ``seq`` and return the flattened regression output."""
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt


class RT_pred_model_self_attention_multi(nn.Module):
    """Bidirectional-GRU encoder, multi-head attention, then a GRU decoder and MLP regression head.

    Token indices are one-hot encoded over ``nb_aa`` classes and linearly
    embedded. Two bidirectional GRUs encode them, and ``nn.MultiheadAttention``
    attends over the sequence positions of each sample. A unidirectional GRU
    plus an MLP reduce the result to one output unit. ``forward`` returns the
    flattened output.

    Args:
        drop_rate: dropout probability applied after each encoder GRU.
        embedding_output_dim: width of the linear token embedding.
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: dropout probability inside the regression MLP.
        regressor_layer_size: hidden width of the regression MLP.
        recurrent_layers_sizes: hidden sizes of the two encoder GRUs and of the
            decoder GRU (three entries are read).
        n_head: number of attention heads; must divide ``recurrent_layers_sizes[1] * 2``.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512, 512), n_head=8):
        super(RT_pred_model_self_attention_multi, self).__init__()
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=self.drop_rate),
        )

        # The encoder GRUs are batch_first, so the attention input is
        # (batch, seq, features). nn.MultiheadAttention defaults to
        # batch_first=False and would then treat the batch axis as the sequence
        # axis, attending across different samples instead of across positions.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, self.n_head, batch_first=True)

        self.decoder = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.recurrent_layers_sizes[2],
                   num_layers=1,
                   bidirectional=False,
                   batch_first=True),
            # NOTE(review): presumably selects the GRU's final hidden state h_n - TODO confirm.
            SelectItem(1),
            nn.Linear(self.recurrent_layers_sizes[2], regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        # NOTE(review): `regressor` and `select` are not used by forward(). They are
        # kept so that existing checkpoints (whose state_dict contains regressor.*)
        # still load and attribute access keeps working.
        self.regressor = nn.Linear(self.regressor_layer_size, self.nb_aa)

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.select = SelectItem(0)
        self.regressor.float()
        self.attention.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        """Encode integer token indices ``seq`` and return the flattened regression output."""
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        # Self-attention over sequence positions within each sample.
        x_att, _ = self.attention(x_enc, x_enc, x_enc)
        x_rt = self.decoder(x_att)
        x_rt_flat = torch.flatten(x_rt)
        return x_rt_flat


class RT_pred_model_self_attention_multi_sum(nn.Module):
    """Two stacked bidirectional GRUs followed by an ``n_head`` self-attention regression head.

    Pipeline: one-hot over ``nb_aa`` classes, then a linear embedding, then
    bidirectional GRU -> ReLU -> dropout -> bidirectional GRU -> dropout, then
    ``SelfAttention_multi(width, n_head)``, then an MLP with a single output
    unit. ``forward`` returns the flattened result.

    Args:
        drop_rate: dropout probability applied after each GRU.
        embedding_output_dim: width of the linear token embedding.
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: dropout probability inside the regression MLP.
        regressor_layer_size: hidden width of the regression MLP.
        recurrent_layers_sizes: hidden sizes of the first and second GRU.
        n_head: number of heads passed to ``SelfAttention_multi``.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), n_head=1):
        super().__init__()
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim

        first_hidden = recurrent_layers_sizes[0]
        second_hidden = recurrent_layers_sizes[1]
        # Both GRUs are bidirectional, so each one's feature width is twice its hidden size.
        attention_width = second_hidden * 2

        # Construction order is preserved so seeded initialisation is unchanged.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=first_hidden, num_layers=1,
                   bidirectional=True, batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.GRU(input_size=first_hidden * 2, hidden_size=second_hidden, num_layers=1,
                   bidirectional=True, batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        self.decoder = nn.Sequential(
            SelfAttention_multi(attention_width, n_head),
            nn.Linear(attention_width, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1),
        )

        self.emb = nn.Linear(nb_aa, embedding_output_dim)

        for part in (self.encoder, self.decoder, self.emb):
            part.float()

    def forward(self, seq):
        """Encode integer token indices ``seq`` and return the flattened regression output."""
        tokens = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        hidden = self.encoder(self.emb(tokens))
        return torch.flatten(self.decoder(hidden))


class RT_pred_model_transformer(nn.Module):
    """Transformer-encoder regressor over fixed-length token sequences.

    Token indices are one-hot encoded over ``nb_aa`` classes and linearly
    embedded. A learned positional embedding (a Linear over one-hot position
    indices) is added. The result passes through two ``TransformerEncoder``
    stages, is flattened, and goes through an MLP with a single output unit.

    Args:
        drop_rate: dropout passed to both ``TransformerEncoder`` stages.
        embedding_output_dim: model width (token and positional embeddings).
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: dropout probability inside the regression MLP.
        regressor_layer_size: hidden width of the regression MLP.
        n_head: number of attention heads passed to ``TransformerEncoder``.
        seq_length: exact sequence length the model accepts (default 30, the
            previously hard-coded value). It sizes the positional embedding and
            the flattened input of the MLP.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=128, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512, n_head=1, seq_length=30):
        super(RT_pred_model_transformer, self).__init__()
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        self.seq_length = seq_length
        self.encoder = nn.Sequential(
            TransformerEncoder(1, input_dim=embedding_output_dim, num_heads=self.n_head, dim_feedforward=512,
                               dropout=self.drop_rate)
        )

        self.decoder = nn.Sequential(
            TransformerEncoder(1, input_dim=embedding_output_dim, num_heads=self.n_head, dim_feedforward=512,
                               dropout=self.drop_rate),
            nn.Flatten(),
            # Flatten merges all positions, so this width fixes the accepted sequence length.
            nn.Linear(embedding_output_dim * self.seq_length, self.regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(self.regressor_layer_size, 1)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)
        self.pos_embedding = nn.Linear(self.seq_length, self.embedding_output_dim)

        self.pos_embedding.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        """Return the flattened regression output for integer token indices ``seq``.

        ``seq`` must contain exactly ``seq_length`` positions along the axis
        that is added to the (seq_length, embedding_output_dim) positional
        embedding.
        """
        # Build the position indices on the same device as the parameters.
        # A bare torch.tensor(...) lands on the CPU and breaks once the model is on a GPU.
        positions = torch.arange(self.seq_length, device=self.pos_embedding.weight.device)
        indice_ohe = torch.nn.functional.one_hot(positions, self.seq_length)
        x_ind = self.pos_embedding(indice_ohe.float())
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb + x_ind)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt


class RT_pred_model_self_attention_pretext(nn.Module):
    """Shared GRU encoder with a regression head and a sequence-reconstruction (pretext) head.

    Token indices are one-hot encoded over ``nb_aa`` classes and linearly
    embedded, then encoded by two bidirectional GRUs. From the encoding:

    * ``decoder`` (``SelfAttention`` + MLP) produces a single-unit regression
      output, returned flattened;
    * ``attention`` + ``decoder_rec`` (GRU) + ``regressor`` produce per-position
      logits over ``nb_aa`` classes for reconstructing the input sequence.

    Args:
        drop_rate: dropout probability applied after each GRU.
        embedding_output_dim: width of the linear token embedding.
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: dropout probability inside the regression MLP.
        regressor_layer_size: hidden width of the regression MLP and of the
            reconstruction GRU.
        recurrent_layers_sizes: hidden sizes of the first and second encoder GRU.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        super(RT_pred_model_self_attention_pretext, self).__init__()
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=self.drop_rate),
        )

        self.decoder_rec = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.regressor_layer_size, num_layers=1,
                   bidirectional=False,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=self.drop_rate),
        )

        # The encoder output is batch-first, (batch, seq, features). Without
        # batch_first=True, nn.MultiheadAttention reads dim 0 as the sequence axis
        # and attends across different samples of the batch.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, 1, batch_first=True)

        self.decoder = nn.Sequential(
            SelfAttention(self.recurrent_layers_sizes[1] * 2),
            nn.Linear(self.recurrent_layers_sizes[1] * 2, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        # Maps each reconstruction-GRU output position to nb_aa class scores.
        self.regressor = nn.Linear(self.regressor_layer_size, self.nb_aa)

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.decoder_rec.float()
        self.regressor.float()
        self.attention.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        """Return ``(x_rt, seq_rec)``.

        ``x_rt`` is the flattened regression output. ``seq_rec`` holds the
        reconstruction scores, whose last dimension is ``nb_aa``.
        """
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        x_rt = self.decoder(x_enc)
        # Pretext branch: self-attention over positions, then GRU decoding and per-position scores.
        enc_att, _ = self.attention(x_enc, x_enc, x_enc)
        dec_att = self.decoder_rec(enc_att)
        seq_rec = self.regressor(dec_att)
        x_rt = torch.flatten(x_rt)
        return x_rt, seq_rec


class Intensity_pred_model_multi_head(nn.Module):
    """Sequence + metadata model that outputs 174 intensity values per sample.

    The sequence is one-hot encoded over ``nb_aa`` classes, embedded, encoded
    by two bidirectional GRUs and passed through self-attention. The metadata
    (``charge`` and ``energy``, concatenated to 7 features) is projected to the
    same width and multiplied element-wise into every sequence position. The
    fused sequence is tiled 6 times along the sequence axis, decoded by a GRU
    and mapped to one value per position. The first 174 values are returned.

    Args:
        drop_rate: dropout probability after each GRU.
        embedding_output_dim: width of the linear token embedding.
        nb_aa: number of one-hot classes; every index in ``seq`` must be < ``nb_aa``.
        latent_dropout_rate: stored but not used by the layers built here.
        regressor_layer_size: hidden size of the decoder GRU.
        recurrent_layers_sizes: hidden sizes of the first and second encoder GRU.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=22, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        super(Intensity_pred_model_multi_head, self).__init__()
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        self.seq_encoder = nn.Sequential(
            nn.GRU(input_size=self.embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True, batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # Metadata encoder: 7 input features (charge and energy concatenated in forward()).
        self.meta_enc = nn.Sequential(nn.Linear(7, self.recurrent_layers_sizes[1] * 2))

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        # The sequence encoding is batch-first, (batch, seq, features). Without
        # batch_first=True, nn.MultiheadAttention reads dim 0 as the sequence axis
        # and attends across different samples of the batch.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, 1, batch_first=True)

        self.decoder = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.regressor_layer_size, num_layers=1,
                   bidirectional=False, batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # NOTE(review): intensities presumably range from 0 to 1, with -1 meaning
        # "impossible". The output of this Linear is not bounded to that range here.
        self.regressor = nn.Linear(self.regressor_layer_size, 1)

        self.meta_enc.float()
        self.seq_encoder.float()
        self.decoder.float()
        self.emb.float()
        self.attention.float()
        self.regressor.float()

    def forward(self, seq, energy, charge):
        """Predict intensities from a token sequence and its metadata.

        Args:
            seq: integer token indices (cast to long), each < ``nb_aa``.
            energy, charge: 2-D metadata tensors concatenated along dim 1 as
                ``[charge, energy]``. Together they must have 7 columns, the
                input width of ``meta_enc``.

        Returns:
            Tensor of shape (batch, min(6 * seq_len, 174)).
        """
        x = torch.nn.functional.one_hot(seq.long(), self.nb_aa)
        x_emb = self.emb(x.float())
        out_1 = self.seq_encoder(x_emb)
        weight_out, _ = self.attention(out_1, out_1, out_1)
        # metadata encoder: one vector per sample, broadcast over every sequence
        # position. This replaces the former repeat(30, 1, 1).transpose(0, 1), which
        # gave the same values but hard-coded a sequence length of 30.
        out_2 = self.meta_enc(torch.concat([charge, energy], 1))
        fusion_encoding = torch.mul(out_2.unsqueeze(1), weight_out)
        # Tile the fused sequence 6 times along the sequence axis before decoding.
        fusion_encoding_rep = fusion_encoding.repeat(1, 6, 1)
        out = self.decoder(fusion_encoding_rep)
        intensity = self.regressor(out)
        intensity = torch.flatten(intensity, start_dim=1)
        # NOTE(review): 174 presumably equals (30 - 1) * 6 fragment slots for
        # length-30 inputs - confirm against the training targets.
        intensity = intensity[:, :174]

        return intensity