From 4c0c4e81df47a5ea72583f81bad05f02fb802c57 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:50:14 +0200
Subject: [PATCH] Add transformer slot attention

---
 osrt/layers.py                    | 15 +++++----------
 osrt/model.py                     | 26 ++++++++++++++++++--------
 runs/clevr3d/slot_att/config.yaml |  1 +
 3 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index 0cfcff1..9eaab5c 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -298,10 +298,8 @@ class SoftPositionEmbed(nn.Module):
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
-    An extension of Slot Attention using self-attention
+    An extension of Slot Attention using self-attention, inspired by "Visual Concepts Tokenization" (Yang et al., 2022).
     """
-    """def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):"""
     def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
                  randomize_initial_slots=False):
         super().__init__()
@@ -355,17 +353,14 @@ class TransformerSlotAttention(nn.Module):
         ############### TODO : adapt this part of code    
         # data = torch.cat((data, enc_pos.reshape(b,-1,enc_pos.shape[-1])), dim = -1) TODO : add a positional encoding here
 
-        x0 = repeat(self.latents, 'n d -> b n d', b = b)
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
+            x = cross_attn(slots, data) + slots # Cross-attention + Residual
+            slots = cross_ff(x) + x # Feed-forward + Residual
 
-            # cross attention only happens once for Perceiver IO
-
-            x = cross_attn(x0, context = data, mask = mask) + x0
-            x0 = cross_ff(x) + x
-
+            # Refine the input tokens with self-attention at every layer except the last
             if i != self.depth - 1:
-                self_attn, self_ff = self.layers[i]
+                self_attn, self_ff = self.sf_layers[i]
                 x_d = self_attn(data) + data
                 data = self_ff(x_d) + x_d
 
diff --git a/osrt/model.py b/osrt/model.py
index 4c46fc6..0db162a 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
 
 import osrt.layers as layers
 
@@ -61,7 +61,7 @@ class SlotAttentionAutoEncoder(nn.Module):
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
     """
 
-    def __init__(self, resolution, num_slots, num_iterations):
+    def __init__(self, resolution, num_slots, num_iterations, cfg):
         """Builds the Slot Attention-based auto-encoder.
 
         Args:
@@ -110,12 +110,22 @@ class SlotAttentionAutoEncoder(nn.Module):
             JaxLinear(64, 64)
         )
 
-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            input_dim=64,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
+        model_type = cfg['model']['model_type']
+        if model_type == 'sa':
+            self.slot_attention = SlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                iters=self.num_iterations)
+        elif model_type == 'tsa':
+            # Reuse the same internal dimensions as the SlotAttention branch
+            self.slot_attention = TransformerSlotAttention(
+                num_slots=self.num_slots,
+                input_dim=64,
+                slot_dim=64,
+                hidden_dim=128,
+                depth=self.num_iterations)  # TransformerSlotAttention takes `depth`, not `iters`
 
     def forward(self, image):
         # `image` has shape: [batch_size, num_channels, width, height].
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 0217dd3..a2600fe 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -3,6 +3,7 @@ data:
 model:
   num_slots: 6
   iters: 3
+  model_type: sa
 training:
   num_workers: 2 
   batch_size: 32 
-- 
GitLab
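
Minimal usage sketch of the new `model_type` switch (not part of the patch): it loads the patched config and builds a SlotAttentionAutoEncoder, which now instantiates SlotAttention for 'sa' or TransformerSlotAttention for 'tsa' internally. The resolution value and the config path are illustrative assumptions.

import yaml

from osrt.model import SlotAttentionAutoEncoder

# Load the run configuration patched above; 'model_type' defaults to 'sa' there.
with open('runs/clevr3d/slot_att/config.yaml') as f:
    cfg = yaml.safe_load(f)

# 'sa' keeps the original SlotAttention module, 'tsa' selects the new
# TransformerSlotAttention, as wired up in SlotAttentionAutoEncoder.__init__.
model = SlotAttentionAutoEncoder(
    resolution=(128, 128),                     # illustrative input resolution
    num_slots=cfg['model']['num_slots'],
    num_iterations=cfg['model']['iters'],
    cfg=cfg)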