# layers.py
import torch
from torch import nn

class SelectItem(nn.Module):
    """Select a single element from an indexable module output (e.g. a tuple returned by an RNN)."""

    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention that sums the weighted values over the sequence axis."""

    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        # Scaled dot-product attention scores: (batch, seq_len, seq_len)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Sum over the sequence axis: (batch, input_dim)
        out = torch.sum(weighted, dim=1)
        return out
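
# Usage sketch (assumption, not from the original file: batch-first input of shape (batch, seq_len, input_dim)):
#   attn = SelfAttention(input_dim=64)
#   out = attn(torch.randn(8, 10, 64))  # out.shape == (8, 64) after the sum over the sequence axis
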
class SelfAttention_multi(nn.Module):
    """Multi-head variant: the input is split into n_head chunks along the feature axis,
    each chunk gets its own query/key/value projection, and the weighted values are summed
    over the sequence axis."""

    def __init__(self, input_dim, n_head=1):
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        super(SelfAttention_multi, self).__init__()
        self.input_dim = input_dim // n_head  # per-head dimension
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        q = []
        k = []
        v = []
        # Project each feature chunk with its own head.
        for i in range(self.n_head):
            q.append(self.query[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            k.append(self.key[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            v.append(self.value[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Sum over the sequence axis: (batch, input_dim * n_head)
        out = torch.sum(weighted, dim=1)
        return out
class SelfAttention_multi_no_sum(nn.Module):
    """Same as SelfAttention_multi but the per-position outputs are kept (no sum over the
    sequence axis), so the output has shape (batch, seq_len, input_dim)."""

    def __init__(self, input_dim, n_head=1):
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        super(SelfAttention_multi_no_sum, self).__init__()
        self.input_dim = input_dim // n_head  # per-head dimension
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, mask=None, return_attention=False):
        q = []
        k = []
        v = []
        for i in range(self.n_head):
            q.append(self.query[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            k.append(self.key[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
            v.append(self.value[i](x[:, :, self.input_dim * i:self.input_dim * (i + 1)]))
        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        if mask is not None:
            # Assumed convention: positions where mask == 0 are not attended to.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        if return_attention:
            # Needed by TransformerEncoder.get_attention_maps below.
            return weighted, attention
        return weighted
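
# Usage sketch (assumptions: input_dim divisible by n_head, batch-first input of shape (batch, seq_len, input_dim)):
#   attn = SelfAttention_multi_no_sum(input_dim=64, n_head=4)
#   out = attn(torch.randn(8, 10, 64))                                # out.shape == (8, 10, 64)
#   out, maps = attn(torch.randn(8, 10, 64), return_attention=True)   # maps.shape == (8, 10, 10)
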
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0, *args, **kwargs):
        """
        Inputs:
            input_dim       - Dimensionality of the input
            num_heads       - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout         - Dropout probability to use in the dropout layers
        """
        super().__init__(*args, **kwargs)
        # Attention layer
        self.self_attn = SelfAttention_multi_no_sum(input_dim, num_heads)
        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim_feedforward, input_dim)
        )
        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention part
        x_n = self.norm1(x)
        attn_out = self.self_attn(x_n)
        x = x + self.dropout(attn_out)
        # MLP part
        x_n = self.norm2(x)
        linear_out = self.linear_net(x_n)
        x = x + self.dropout(linear_out)
        return x
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_attention_maps(self, x, mask=None):
        # Collect one attention map per block; relies on the attention module
        # returning (output, attention) when return_attention=True.
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x)
        return attention_maps
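

# Minimal smoke-test sketch (assumptions: batch-first input; the dimensions below are
# hypothetical, chosen only for illustration and not part of the original file).
if __name__ == "__main__":
    model = TransformerEncoder(num_layers=2, input_dim=64, num_heads=4, dim_feedforward=128, dropout=0.1)
    dummy = torch.randn(8, 10, 64)            # (batch, seq_len, input_dim)
    out = model(dummy)                        # same shape as the input: (8, 10, 64)
    maps = model.get_attention_maps(dummy)    # list of (8, 10, 10) attention maps, one per block
    print(out.shape, len(maps), maps[0].shape)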