# layers.py
import torch
from torch import nn


class SelectItem(nn.Module):
    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]

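# Usage sketch (illustrative, not part of the original module): SelectItem is
# typically used inside nn.Sequential to pick one element of a tuple output,
# e.g. the output sequence of an nn.LSTM, which returns (output, (h_n, c_n)).
# All sizes below are made up for the example.
#
#   recurrent = nn.Sequential(
#       nn.LSTM(input_size=16, hidden_size=32, batch_first=True),
#       SelectItem(0),               # keep only the output sequence
#       nn.Linear(32, 8),
#   )
#   y = recurrent(torch.randn(4, 10, 16))  # -> (4, 10, 8)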

class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Collapse the attended sequence into one vector per example
        out = torch.sum(weighted, dim=1)
        return out

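# Shape sketch (illustrative): SelfAttention expects x of shape
# (batch, seq_len, input_dim) and returns one vector per example, because the
# attended sequence is summed over the time dimension.
#
#   attn = SelfAttention(input_dim=64)
#   out = attn(torch.randn(8, 20, 64))  # -> (8, 64)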

class SelfAttention_multi(nn.Module):
    def __init__(self, input_dim, n_head=1):
        super(SelfAttention_multi, self).__init__()
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        self.input_dim = input_dim // n_head  # dimensionality of a single head
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        q = []
        k = []
        v = []
        for i in range(self.n_head):
            q.append(self.query[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            k.append(self.key[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            v.append(self.value[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))

        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Collapse the attended sequence into one vector per example
        out = torch.sum(weighted, dim=1)
        return out


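# SelfAttention_multi_no_sum is identical to SelfAttention_multi except that it
# returns the full attended sequence of shape (batch, seq_len, input_dim)
# instead of summing over the time dimension, which is what EncoderBlock needs
# for its residual connection. It can also hand back the attention weights
# (return_attention=True) so that TransformerEncoder.get_attention_maps works.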
class SelfAttention_multi_no_sum(nn.Module):
    def __init__(self, input_dim, n_head=1):
        super(SelfAttention_multi_no_sum, self).__init__()
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        self.input_dim = input_dim // n_head  # dimensionality of a single head
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, return_attention=False):
        q = []
        k = []
        v = []
        for i in range(self.n_head):
            q.append(self.query[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            k.append(self.key[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            v.append(self.value[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))

        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        if return_attention:
            return weighted, attention
        return weighted


class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0, *args, **kwargs):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """

        super().__init__(*args, **kwargs)

        # Attention layer
        self.self_attn = SelfAttention_multi_no_sum(input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention part
        x_n = self.norm1(x)
        attn_out = self.self_attn(x_n)
        x = x + self.dropout(attn_out)

        # MLP part
        x_n = self.norm2(x)
        linear_out = self.linear_net(x_n)
        x = x + self.dropout(linear_out)

        return x


class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_attention_maps(self, x):
        # Collect the attention weights from each block. Masking is not
        # supported by SelfAttention_multi_no_sum, so no mask argument is
        # accepted here; norm1 is applied first so the maps match what each
        # block's forward pass actually feeds to its attention layer.
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(layer.norm1(x), return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x)
        return attention_maps
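

# Minimal smoke test (illustrative only; all sizes and hyperparameters below
# are made up, not part of the original module). It runs a small
# TransformerEncoder on random data and inspects the per-layer attention maps.
if __name__ == "__main__":
    encoder = TransformerEncoder(
        num_layers=2,
        input_dim=32,
        num_heads=4,
        dim_feedforward=64,
        dropout=0.1,
    )
    x = torch.randn(8, 10, 32)            # (batch, seq_len, input_dim)
    out = encoder(x)                       # -> (8, 10, 32)
    maps = encoder.get_attention_maps(x)   # list of 2 maps, each (8, 10, 10)
    print(out.shape, [m.shape for m in maps])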