diff --git a/osrt/layers.py b/osrt/layers.py
index 6662e554a5fb2d151ea60c7cd0cf830266af2b74..24c363f866c6b003ff4c1e297b0e84a2bbd90258 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -327,9 +327,9 @@ class TransformerSlotAttention(nn.Module):
         ### Cross-attention layers
         self.cs_layers = nn.ModuleList([])
         for _ in range(depth):
-            # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
+            # The +26 accounts for the Fourier positional encoding concatenated to the inputs in forward()
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False), self.input_dim),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim+26, selfatt=False), self.input_dim + 26),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
 
@@ -337,8 +337,8 @@ class TransformerSlotAttention(nn.Module):
         self.sf_layers = nn.ModuleList([])
         for _ in range(depth-1):
             self.sf_layers.append(nn.ModuleList([
-                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_head, dim_head = self.hidden_dim)),
-                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+                PreNorm(self.input_dim+26, Attention(self.input_dim+26, heads=self.self_head, dim_head = self.hidden_dim)),
+                PreNorm(self.input_dim+26, FeedForward(self.input_dim+26, self.hidden_dim))
             ]))
 
         ### Initialize slots
@@ -354,10 +354,6 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
-        print(f"Shape inputs first : {inputs.shape}")
-
-        inputs = self.norm_input(inputs)
-        print(f"Shape inputs after norm : {inputs.shape}")
         
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -370,11 +366,10 @@ class TransformerSlotAttention(nn.Module):
         pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
         enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
         enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
-        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
+        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size) # enc_pos feature dim is 26 (Fourier positional encoding)
+
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1) 
 
-        inputs = self.norm_input(inputs)
-        print(f"Shape inputs after encoding : {inputs.shape}")
 
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
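
Note on the hard-coded 26: with a Perceiver-style `fourier_encode` (per axis, `num_freq_bands` sine terms, `num_freq_bands` cosine terms, plus the raw coordinate, then flattened by the `rearrange` call), 26 matches 2 spatial axes with 6 frequency bands. The snippet below is a minimal sketch of that arithmetic only; the axis count and band count are assumptions, not values read from this diff.

```python
# Sketch of the positional-encoding width behind the +26 (assumed parameters).
# Assumes a Perceiver-style fourier_encode: per axis, num_freq_bands sin terms,
# num_freq_bands cos terms, plus the raw coordinate, flattened across axes.

def fourier_pos_dim(num_axes: int, num_freq_bands: int) -> int:
    """Feature width of the flattened Fourier positional encoding."""
    per_axis = 2 * num_freq_bands + 1  # sin + cos bands + raw coordinate
    return num_axes * per_axis

if __name__ == "__main__":
    # Assumed: 2 spatial axes, 6 frequency bands -> 2 * (2*6 + 1) = 26
    print(fourier_pos_dim(num_axes=2, num_freq_bands=6))  # 26
```

If either the number of spatial axes or `self.num_freq_bands` changes, the `+26` constants in the `Attention`/`PreNorm` dimensions above would need to change with it.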