diff --git a/osrt/layers.py b/osrt/layers.py
index d74ff57f8b4fd7bc7da7ca96666ba60700c26b33..6662e554a5fb2d151ea60c7cd0cf830266af2b74 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -354,8 +354,10 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
+        print(f"Shape inputs first : {inputs.shape}")
 
-        #inputs = self.norm_input(inputs)
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after norm : {inputs.shape}")
         
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -371,9 +373,11 @@ class TransformerSlotAttention(nn.Module):
         enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1) 
 
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after encoding : {inputs.shape}")
+
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            print(f"Shape inputs : {inputs}")
             x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
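
For context, the hunks above touch the input-normalization and cross-attention steps of TransformerSlotAttention.forward: normalize the flattened inputs, concatenate the positional encoding, then run `depth` rounds of cross-attention plus feed-forward with residual connections. Below is a minimal, self-contained sketch of that flow, not the repository's actual implementation. Attribute names (norm_input, initial_slots, cs_layers, depth) follow the diff; every dimension, the MultiheadAttention/MLP blocks, and the extra norm_cat layer applied after concatenation are illustrative assumptions.

# Minimal sketch of the slot-attention forward pass the diff modifies.
# Names `norm_input`, `initial_slots`, `cs_layers`, `depth` mirror the diff;
# dimensions, the attention/MLP blocks, and `norm_cat` are assumptions.
import torch
import torch.nn as nn


class SlotAttentionSketch(nn.Module):
    def __init__(self, input_dim=64, pos_dim=16, slot_dim=80, num_slots=6, depth=3):
        super().__init__()
        self.depth = depth
        self.norm_input = nn.LayerNorm(input_dim)
        # Assumed: a second norm sized for the inputs after the positional
        # features are concatenated (input_dim + pos_dim == slot_dim here).
        self.norm_cat = nn.LayerNorm(input_dim + pos_dim)
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        self.cs_layers = nn.ModuleList([
            nn.ModuleList([
                nn.MultiheadAttention(slot_dim, num_heads=4, batch_first=True),
                nn.Sequential(nn.Linear(slot_dim, 4 * slot_dim), nn.GELU(),
                              nn.Linear(4 * slot_dim, slot_dim)),
            ])
            for _ in range(depth)
        ])

    def forward(self, inputs, enc_pos):
        batch_size = inputs.shape[0]
        inputs = self.norm_input(inputs)               # [B, N, input_dim]
        inputs = torch.cat((inputs, enc_pos), dim=-1)  # [B, N, input_dim + pos_dim]
        inputs = self.norm_cat(inputs)
        # Broadcast the learned initial slots across the batch.
        slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            # Cross-attention (slots attend to inputs) + residual
            x = cross_attn(slots, inputs, inputs)[0] + slots
            # Feed-forward + residual
            slots = cross_ff(x) + x
        return slots


# Usage: 2 scenes, 100 flattened input tokens, matching positional features.
if __name__ == "__main__":
    model = SlotAttentionSketch()
    feats = torch.randn(2, 100, 64)
    pos = torch.randn(2, 100, 16)
    print(model(feats, pos).shape)  # torch.Size([2, 6, 80])

The sketch uses a separate norm_cat sized to input_dim + pos_dim because a LayerNorm whose normalized_shape matches the pre-concatenation width cannot be reapplied once the positional features widen the last dimension.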