import torch
from torch import nn


class SelectItem(nn.Module):
    """Select a single element from a tuple/list output (e.g. to pick one of the outputs of an nn.LSTM)."""

    def __init__(self, item_index):
        super().__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention, summed over the sequence dimension."""

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        # Scaled dot-product attention scores: (batch, seq_len, seq_len)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Sum over the sequence dimension to get one vector per sample
        out = torch.sum(weighted, dim=1)
        return out


class SelfAttention_multi(nn.Module):
    """Multi-head variant: the input is split into `n_head` chunks along the feature
    dimension, each chunk gets its own Q/K/V projections, and the projected chunks are
    concatenated back before a single shared attention matrix is computed."""

    def __init__(self, input_dim, n_head=1):
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        super().__init__()
        self.input_dim = input_dim // n_head  # per-head dimensionality
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # Project each feature chunk with its own head, then re-concatenate.
        q, k, v = [], [], []
        for i in range(self.n_head):
            chunk = x[:, :, self.input_dim * i:self.input_dim * (i + 1)]
            q.append(self.query[i](chunk))
            k.append(self.key[i](chunk))
            v.append(self.value[i](chunk))
        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Sum over the sequence dimension to get one vector per sample
        out = torch.sum(weighted, dim=1)
        return out


class SelfAttention_multi_no_sum(nn.Module):
    """Same as SelfAttention_multi, but returns the full (batch, seq_len, input_dim)
    attention output instead of summing over the sequence dimension."""

    def __init__(self, input_dim, n_head=1):
        if input_dim % n_head != 0:
            raise ValueError("input_dim must be divisible by n_head")
        super().__init__()
        self.input_dim = input_dim // n_head  # per-head dimensionality
        self.n_head = n_head
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, mask=None, return_attention=False):
        # `mask` and `return_attention` support TransformerEncoder.get_attention_maps.
        q, k, v = [], [], []
        for i in range(self.n_head):
            chunk = x[:, :, self.input_dim * i:self.input_dim * (i + 1)]
            q.append(self.query[i](chunk))
            k.append(self.key[i](chunk))
            v.append(self.value[i](chunk))
        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        if mask is not None:
            # Mask out disallowed positions before the softmax
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        if return_attention:
            return weighted, attention
        return weighted


class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0, *args, **kwargs):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__(*args, **kwargs)
        # Attention layer
        self.self_attn = SelfAttention_multi_no_sum(input_dim, num_heads)
        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim_feedforward, input_dim),
        )
        # Layers to apply in between the main layers (pre-norm residual structure)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention part (pre-norm residual connection)
        x_n = self.norm1(x)
        attn_out = self.self_attn(x_n)
        x = x + self.dropout(attn_out)
        # MLP part (pre-norm residual connection)
        x_n = self.norm2(x)
        linear_out = self.linear_net(x_n)
        x = x + self.dropout(linear_out)
        return x


class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_attention_maps(self, x, mask=None):
        # Collect each layer's attention matrix for inspection/visualisation,
        # propagating the input through the layers as we go.
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x)
        return attention_maps