Commit 4c0c4e81 authored by Alexandre Chapin

Add transformer slot attention

parent 1915d2fe
@@ -298,10 +298,8 @@ class SoftPositionEmbed(nn.Module):
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
-    An extension of Slot Attention using self-attention
+    An extension of Slot Attention using self-attention, inspired by "Visual Concepts Tokenization" (Yang et al., 2022)
     """
-    """def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                    randomize_initial_slots=False):"""
     def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
                  randomize_initial_slots=False):
         super().__init__()
@@ -355,17 +353,14 @@ class TransformerSlotAttention(nn.Module):
         ############### TODO : adapt this part of code
         # data = torch.cat((data, enc_pos.reshape(b,-1,enc_pos.shape[-1])), dim = -1) TODO : add a positional encoding here
-        x0 = repeat(self.latents, 'n d -> b n d', b = b)
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            # cross attention only happens once for Perceiver IO
-            x = cross_attn(x0, context = data, mask = mask) + x0
-            x0 = cross_ff(x) + x
+            x = cross_attn(slots, data) + slots  # Cross-attention + Residual
+            slots = cross_ff(x) + x  # Feed-forward + Residual
+            ## Apply self-attention on the input tokens, but only before the last layer
             if i != self.depth - 1:
-                self_attn, self_ff = self.layers[i]
+                self_attn, self_ff = self.sf_layers[i]
                 x_d = self_attn(data) + data
                 data = self_ff(x_d) + x_d
...
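For readers without the rest of layers.py at hand, here is a minimal, self-contained sketch of the control flow this hunk settles on. It is not the repository's code: the attention and feed-forward helpers behind cs_layers and sf_layers are not part of this diff, so the blocks below are illustrative stand-ins that assume a pre-norm residual structure.

# Minimal sketch only; helper modules are assumptions, not the repository's implementations.
import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    """Pre-norm multi-head attention; passing context=x makes it self-attention."""
    def __init__(self, dim, heads=1):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, context):
        q, kv = self.norm_q(x), self.norm_kv(context)
        out, _ = self.attn(q, kv, kv, need_weights=False)
        return out

class FeedForward(nn.Module):
    """Pre-norm MLP block (stand-in for cross_ff / self_ff)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim))

    def forward(self, x):
        return self.net(x)

class TransformerSlotAttentionSketch(nn.Module):
    """Same control flow as the hunk above: slots cross-attend to the input tokens,
    then the tokens themselves are refined by self-attention on every layer but the last."""
    def __init__(self, num_slots=10, depth=6, dim=64, hidden_dim=128, cross_heads=1, self_heads=4):
        super().__init__()
        self.depth = depth
        self.initial_slots = nn.Parameter(torch.randn(num_slots, dim))
        self.cs_layers = nn.ModuleList(
            nn.ModuleList([AttentionBlock(dim, cross_heads), FeedForward(dim, hidden_dim)])
            for _ in range(depth))
        self.sf_layers = nn.ModuleList(
            nn.ModuleList([AttentionBlock(dim, self_heads), FeedForward(dim, hidden_dim)])
            for _ in range(depth - 1))

    def forward(self, data):
        # data: [batch_size, num_tokens, dim]
        slots = self.initial_slots.unsqueeze(0).expand(data.shape[0], -1, -1)
        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            x = cross_attn(slots, data) + slots   # cross-attention + residual
            slots = cross_ff(x) + x               # feed-forward + residual
            if i != self.depth - 1:               # refine tokens, except on the last layer
                self_attn, self_ff = self.sf_layers[i]
                x_d = self_attn(data, data) + data
                data = self_ff(x_d) + x_d
        return slots  # [batch_size, num_slots, dim]

For example, TransformerSlotAttentionSketch()(torch.randn(2, 196, 64)) returns a [2, 10, 64] slot tensor.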
@@ -4,7 +4,7 @@ import numpy as np
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
 import osrt.layers as layers
@@ -61,7 +61,7 @@ class SlotAttentionAutoEncoder(nn.Module):
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
     """
-    def __init__(self, resolution, num_slots, num_iterations):
+    def __init__(self, resolution, num_slots, num_iterations, cfg):
         """Builds the Slot Attention-based auto-encoder.
         Args:
@@ -110,12 +110,22 @@ class SlotAttentionAutoEncoder(nn.Module):
             JaxLinear(64, 64)
         )
-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            input_dim=64,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
+        model_type = cfg['model']['model_type']
+        if model_type == 'sa':
+            self.slot_attention = SlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                iters=self.num_iterations)
+        elif model_type == 'tsa':
+            # We keep the same internal parameter sizes
+            self.slot_attention = TransformerSlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                depth=self.num_iterations)  # TransformerSlotAttention takes depth, not iters

     def forward(self, image):
         # `image` has shape: [batch_size, num_channels, width, height].
...
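A brief, hypothetical sketch of how the new cfg argument could be passed in from a training script. The module path, the resolution value, and the inline dict are illustrative assumptions only; in the repository the dict would come from the yaml config shown in the next hunk.

# Illustrative only: cfg is normally loaded from the yaml config (see hunk below).
from osrt.model import SlotAttentionAutoEncoder  # assumed module path for this file

cfg = {'model': {'num_slots': 6, 'iters': 3, 'model_type': 'tsa'}}

model = SlotAttentionAutoEncoder(
    resolution=(128, 128),                 # assumed input resolution
    num_slots=cfg['model']['num_slots'],
    num_iterations=cfg['model']['iters'],
    cfg=cfg)                               # 'sa' -> SlotAttention, 'tsa' -> TransformerSlotAttention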
@@ -3,6 +3,7 @@ data:
 model:
   num_slots: 6
   iters: 3
+  model_type: sa
 training:
   num_workers: 2
   batch_size: 32
...
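And a minimal sketch of reading the new flag at startup; the file name is a placeholder, and only the keys visible in the hunk above are assumed to exist.

import yaml

with open('config.yaml') as f:   # placeholder path
    cfg = yaml.safe_load(f)

# 'sa' keeps the original SlotAttention; set model_type: tsa to use TransformerSlotAttention.
model_type = cfg['model'].get('model_type', 'sa')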