From 4c0c4e81df47a5ea72583f81bad05f02fb802c57 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:50:14 +0200
Subject: [PATCH] Add transformer slot attention

---
 osrt/layers.py                    | 15 +++++----------
 osrt/model.py                     | 26 ++++++++++++++++++--------
 runs/clevr3d/slot_att/config.yaml |  1 +
 3 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index 0cfcff1..9eaab5c 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -298,10 +298,8 @@ class SoftPositionEmbed(nn.Module):
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
-    An extension of Slot Attention using self-attention
+    An extension of Slot Attention using transformer-style self-attention, inspired by "Visual Concepts Tokenization" (Yang et al., 2022)
     """
-    """def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):"""
     def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6, randomize_initial_slots=False):
         super().__init__()
@@ -355,17 +353,14 @@ class TransformerSlotAttention(nn.Module):
             ############### TODO : adapt this part of code
             # data = torch.cat((data, enc_pos.reshape(b,-1,enc_pos.shape[-1])), dim = -1) TODO : add a positional encoding here
-        x0 = repeat(self.latents, 'n d -> b n d', b = b)
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
+            x = cross_attn(slots, data) + slots  # Cross-attention + residual
+            slots = cross_ff(x) + x  # Feed-forward + residual
-            # cross attention only happens once for Perceiver IO
-
-            x = cross_attn(x0, context = data, mask = mask) + x0
-            x0 = cross_ff(x) + x
-
+            ## Apply self-attention on the input tokens, but only before the last layer
             if i != self.depth - 1:
-                self_attn, self_ff = self.layers[i]
+                self_attn, self_ff = self.sf_layers[i]
                 x_d = self_attn(data) + data
                 data = self_ff(x_d) + x_d

diff --git a/osrt/model.py b/osrt/model.py
index 4c46fc6..0db162a 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -4,7 +4,7 @@ import numpy as np

 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
 import osrt.layers as layers

@@ -61,7 +61,7 @@ class SlotAttentionAutoEncoder(nn.Module):
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
     """

-    def __init__(self, resolution, num_slots, num_iterations):
+    def __init__(self, resolution, num_slots, num_iterations, cfg):
        """Builds the Slot Attention-based auto-encoder.

        Args:
@@ -110,12 +110,22 @@ class SlotAttentionAutoEncoder(nn.Module):
             JaxLinear(64, 64)
         )

-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            input_dim=64,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
+        model_type = cfg['model']['model_type']
+        if model_type == 'sa':
+            self.slot_attention = SlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                iters=self.num_iterations)
+        elif model_type == 'tsa':
+            # Match the inner dimensions of the SlotAttention variant
+            self.slot_attention = TransformerSlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                depth=self.num_iterations)

     def forward(self, image):
         # `image` has shape: [batch_size, num_channels, width, height].
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 0217dd3..a2600fe 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -3,6 +3,7 @@ data:
 model:
   num_slots: 6
   iters: 3
+  model_type: sa
 training:
   num_workers: 2
   batch_size: 32
-- 
GitLab
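
For reference, below is a minimal, self-contained PyTorch sketch of the per-layer update that the layers.py hunk introduces: the slots cross-attend to the input tokens with residual connections around both the attention and the feed-forward step, and the input tokens are refined by self-attention between layers, skipped at the last depth. The PreNormAttention and PreNormFF wrappers, the initial_slots parameter, and the single shared dim are illustrative assumptions: the diff only shows that self.cs_layers[i] yields a (cross-attention, feed-forward) pair and self.sf_layers[i] a (self-attention, feed-forward) pair, not how those modules are built, and the real class exposes separate input_dim and slot_dim.

import torch
import torch.nn as nn


class PreNormAttention(nn.Module):
    """LayerNorm on the query, then multi-head attention (key/value = context).
    Stand-in for whatever attention wrapper the repository actually uses."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, context=None):
        context = x if context is None else context
        out, _ = self.attn(self.norm(x), context, context)
        return out


class PreNormFF(nn.Module):
    """LayerNorm followed by a two-layer MLP."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class TransformerSlotAttentionSketch(nn.Module):
    """Illustrative stand-in for TransformerSlotAttention (not the repo's class)."""
    def __init__(self, num_slots=10, depth=6, dim=64, hidden_dim=128,
                 cross_heads=1, self_heads=4):
        super().__init__()
        self.depth = depth
        # Learned initial slots, broadcast over the batch at forward time.
        self.initial_slots = nn.Parameter(torch.randn(num_slots, dim))
        self.cs_layers = nn.ModuleList([
            nn.ModuleList([PreNormAttention(dim, cross_heads), PreNormFF(dim, hidden_dim)])
            for _ in range(depth)])
        self.sf_layers = nn.ModuleList([
            nn.ModuleList([PreNormAttention(dim, self_heads), PreNormFF(dim, hidden_dim)])
            for _ in range(depth - 1)])

    def forward(self, data):
        # data: [batch, num_inputs, dim] flattened encoder features.
        b = data.shape[0]
        slots = self.initial_slots.unsqueeze(0).expand(b, -1, -1)
        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            x = cross_attn(slots, data) + slots   # cross-attention + residual
            slots = cross_ff(x) + x               # feed-forward + residual
            if i != self.depth - 1:               # refine inputs, except after the last block
                self_attn, self_ff = self.sf_layers[i]
                x_d = self_attn(data) + data
                data = self_ff(x_d) + x_d
        return slots


if __name__ == "__main__":
    tokens = torch.randn(2, 32, 64)                         # e.g. flattened CNN features
    print(TransformerSlotAttentionSketch()(tokens).shape)   # torch.Size([2, 10, 64])

Unlike the iterative SlotAttention module, the number of refinement steps here is fixed by depth, i.e. by how many cross-attention blocks are stacked, rather than by an iters argument.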
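
The model.py and config.yaml hunks route the new behaviour through the config file, so a hypothetical usage sketch looks as follows. Only the config keys visible in the diff (num_slots, iters, model_type) come from the patch; the (64, 64) resolution and the assumption that PyYAML and the osrt package are importable are illustrative.

import yaml  # PyYAML, assumed available

from osrt.model import SlotAttentionAutoEncoder  # requires the osrt package on the path

# Load the run configuration touched by this patch.
with open("runs/clevr3d/slot_att/config.yaml") as f:
    cfg = yaml.safe_load(f)

# 'sa' keeps the original SlotAttention module; 'tsa' switches to TransformerSlotAttention.
cfg["model"]["model_type"] = "tsa"

model = SlotAttentionAutoEncoder(
    resolution=(64, 64),                   # illustrative value, not taken from the diff
    num_slots=cfg["model"]["num_slots"],   # 6 in the shipped config
    num_iterations=cfg["model"]["iters"],  # 3 in the shipped config
    cfg=cfg,                               # new argument introduced by this patch
)

With model_type left at 'sa', the shipped default, the autoencoder behaves exactly as before the patch; 'tsa' swaps in the transformer-based slot module while keeping the same slot count and feature dimensions.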