diff --git a/.visualisation_0.png b/.visualisation_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f0184272777793b0d4181822f23ecf15a2b2a1d
Binary files /dev/null and b/.visualisation_0.png differ
diff --git a/evaluate_sa.py b/evaluate_sa.py
index af75b4bb89159700ede32361e378238a24693ac0..e5e06eb7ec4990daa29ba21aead49de35b00f270 100644
--- a/evaluate_sa.py
+++ b/evaluate_sa.py
@@ -38,7 +38,7 @@ def main():
     resolution = (128, 128)
 
     #### Create datasets
-    test_dataset = data.get_dataset('val', cfg['data'])
+    test_dataset = data.get_dataset('test', cfg['data'])
     test_dataloader = DataLoader(
         test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)
diff --git a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0
new file mode 100644
index 0000000000000000000000000000000000000000..9e9c05c6e308e39ca9c9298dc2b9b8a32a4b516b
Binary files /dev/null and b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 differ
diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_2/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0
new file mode 100644
index 0000000000000000000000000000000000000000..a7cc5deb3a254b64b177f15573a663f3ad9d0de6
Binary files /dev/null and b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 differ
diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_3/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0
new file mode 100644
index 0000000000000000000000000000000000000000..65a1f5c0cfc4c7d12dea17678f70732aaaf2731c
Binary files /dev/null and b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 differ
diff --git a/lightning_logs/version_4/hparams.yaml b/lightning_logs/version_4/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_4/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0
new file mode 100644
index 0000000000000000000000000000000000000000..73d3b9d0a715b4f3a4b30267a21646af46b3fabe
Binary files /dev/null and b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 differ
diff --git a/lightning_logs/version_5/hparams.yaml b/lightning_logs/version_5/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_5/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/osrt/layers.py b/osrt/layers.py
index 24c363f866c6b003ff4c1e297b0e84a2bbd90258..9b7c55543c99c37cf3361a3934e1d23acf559437 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -5,6 +5,8 @@ import numpy as np
 import math
 
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
 import torch.nn.functional as F
@@ -303,6 +305,66 @@ def fourier_encode(x, max_freq, num_bands = 4):
     x = torch.cat((x, orig_x), dim = -1)
     return x
 
+
+class AutoEncoder(nn.Module):
+    def __init__(self, patch_size, image_size, emb_dim):
+        super().__init__()
+        self.patchify = nn.Conv2d(3, emb_dim, patch_size, patch_size)
+        self.head = nn.Linear(emb_dim, 3 * patch_size ** 2)
+        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)
+
+    def encode(self, img):
+        return self.patchify(img)
+
+    def decode(self, feature):
+        feature = feature.reshape(feature.shape[0], feature.shape[1], -1).permute(1, 0, 2)
+        return self.patch2img(self.head(feature))
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True), # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True), # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
+        )
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
+        )
+
+    def forward(self, x):
+        x = self.encoder_cnn(x).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.decoder_initial_size = (8, 8)
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
+        )
+
+        self.decoder_pos = PositionEmbeddingImplicit(64)
+
+    def forward(self, x):
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        return x
+
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/osrt/model.py b/osrt/model.py
index d176d9abd4d115b3c6268625988327e9ada1f67b..932abd9224e956af29d9eb1d05b426b77461e465 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -9,7 +9,7 @@ import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
+from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
 import osrt.layers as layers
 
 from osrt.utils.common import mse2psnr
@@ -66,32 +66,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         self.criterion = nn.MSELoss()
 
-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
-        )
-
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 4, kernel_size=3)
-        )
-
-        self.encoder_pos = PositionEmbeddingImplicit(64)
-        self.decoder_pos = PositionEmbeddingImplicit(64)
-
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            nn.Linear(64, 64),
-            nn.ReLU(inplace=True),
-            nn.Linear(64, 64)
-        )
+        self.encoder = Encoder()
+        self.decoder = Decoder()
 
         model_type = cfg['model']['model_type']
         if model_type == 'sa':
@@ -111,15 +87,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
             depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
-
+        x = self.encoder(image)
         slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)
 
         x = F.interpolate(x, image.shape[-2:], mode='bilinear')
         x = x.unflatten(0, (len(image), len(x) // len(image)))
@@ -135,14 +107,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return optimizer
 
     def one_step(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
+        x = self.encoder(image)
 
         slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)
 
         x = F.interpolate(x, image.shape[-2:], mode='bilinear')
 
@@ -152,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
diff --git a/outputs/visualisation_12000.png b/outputs/visualisation_12000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8e3a04cb6766c6d03e3f695d2f6dd99d7643aa7
Binary files /dev/null and b/outputs/visualisation_12000.png differ
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 161aaf628a68e847d514e2e78bdf5867ceb2167d..82c0f8e78d858da06e976697291a7189012fba51 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -1,13 +1,13 @@
 data:
   dataset: clevr3d
 model:
-  num_slots: 6
+  num_slots: 10
   iters: 3
   model_type: sa
 training:
   num_workers: 2
   num_gpus: 1
-  batch_size: 64
+  batch_size: 8
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
diff --git a/runs/clevr3d/slot_att/config_tsa.yaml b/runs/clevr3d/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f03fc97062dd583a0884019890f8fbd0778280b1
--- /dev/null
+++ b/runs/clevr3d/slot_att/config_tsa.yaml
@@ -0,0 +1,15 @@
+data:
+  dataset: clevr3d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 32
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+
diff --git a/visualize_sa.py b/visualize_sa.py
index a1d4ce15204654a4018d62ecc25e34a6341c7a72..881d9d474000ee763ed1877310a79ebf2c755119 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -50,7 +50,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
     checkpoint = torch.load(args.ckpt)
     model.load_state_dict(checkpoint['state_dict'])
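
Note (not part of the patch): below is a minimal smoke-test sketch of the Encoder and Decoder refactored into osrt/layers.py, mirroring how the patched osrt/model.py uses them. It assumes the classes are importable exactly as in the patched imports, that PositionEmbeddingImplicit accepts the channels-last tensors it already receives in the existing code, and that the slot dimension is 64 so broadcast slots match the decoder's first ConvTranspose2d; the batch size, resolution, and num_slots are illustrative only.

    import torch
    from osrt.layers import Encoder, Decoder

    encoder = Encoder()
    decoder = Decoder()

    # Encode a dummy batch: the two stride-2 convs downsample 128x128 -> 32x32,
    # and the encoder returns channels-last features, i.e. (batch, 32, 32, 64).
    images = torch.randn(2, 3, 128, 128)
    features = encoder(images)
    print(features.shape)

    # Broadcast dummy slots over the decoder's 8x8 initial grid, as in
    # LitSlotAttentionAutoEncoder.forward, then decode to per-slot 4-channel
    # maps (RGB + mask logits); the model later resizes them with F.interpolate.
    batch_size, num_slots, slot_dim = 2, 10, 64
    slots = torch.randn(batch_size * num_slots, slot_dim)
    x = slots.reshape(-1, 1, 1, slot_dim).expand(-1, *decoder.decoder_initial_size, -1)
    out = decoder(x)
    print(out.shape)  # (batch_size * num_slots, 4, H', W')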