diff --git a/osrt/layers.py b/osrt/layers.py
index aadbdf50c6356436ab672f9f9c4f756ab1cbb445..d74ff57f8b4fd7bc7da7ca96666ba60700c26b33 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -355,7 +355,7 @@ class TransformerSlotAttention(nn.Module):
         batch_size, *axis = inputs.shape
         device = inputs.device
 
-        inputs = self.norm_input(inputs)
+        # inputs = self.norm_input(inputs)  # input normalization disabled
         
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -368,12 +368,12 @@ class TransformerSlotAttention(nn.Module):
         pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
         enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
         enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
         enc_pos = repeat(enc_pos, '... -> b ...', b = batch_size)
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1) 
 
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            print(f"Shape slots {slots.shape} an inputs shape {inputs.shape}")
+            print(f"Shape inputs : {inputs}")
             x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
 
diff --git a/outputs/visualisation_8000.png b/outputs/visualisation_8000.png
new file mode 100644
index 0000000000000000000000000000000000000000..84ae3fad1a4d8968c3bd89ebb65bbba276509a3b
Binary files /dev/null and b/outputs/visualisation_8000.png differ
diff --git a/outputs/visualisation_9000.png b/outputs/visualisation_9000.png
new file mode 100644
index 0000000000000000000000000000000000000000..55588a785d95b9e18267c5003bd6813c6550f3a6
Binary files /dev/null and b/outputs/visualisation_9000.png differ
diff --git a/visualize_sa.py b/visualize_sa.py
index 5d4c67f67a093fafe0a9f9fcc2450e10d84fa4a9..a1d4ce15204654a4018d62ecc25e34a6341c7a72 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -50,7 +50,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 6, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)  # num_slots: 6 -> 10
     checkpoint = torch.load(args.ckpt)
     
     model.load_state_dict(checkpoint['state_dict'])
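
Note: the second positional argument of LitSlotAttentionAutoEncoder is taken here to be num_slots (assumed from the call site, where num_iterations follows it); it has to match the slot count the checkpoint was trained with, or load_state_dict will fail on the learned initial-slot parameters. A minimal sketch that infers the count from the checkpoint instead of hard-coding 10 — the [num_slots, slot_dim] shape of initial_slots comes from osrt/layers.py above, while the exact state-dict key name is an assumption:

    checkpoint = torch.load(args.ckpt, map_location=device)
    state_dict = checkpoint['state_dict']
    # initial_slots is stored as [num_slots, slot_dim]; read the slot count
    # off its first dimension. The key suffix is assumed; adjust if needed.
    slot_key = next(k for k in state_dict if k.endswith('initial_slots'))
    num_slots = state_dict[slot_key].shape[0]
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations,
                                        cfg=cfg).to(device)
    model.load_state_dict(state_dict)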