diff --git a/osrt/layers.py b/osrt/layers.py
index fbec56074ce6d4ec7c5e611f9f9406a756da761e..aadbdf50c6356436ab672f9f9c4f756ab1cbb445 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -103,8 +103,8 @@ class RayEncoder(nn.Module):
 class PreNorm(nn.Module):
     def __init__(self, dim, fn, cross_dim=None):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
         self.fn = fn
+        self.norm = nn.LayerNorm(dim)
         self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None
 
     def forward(self, x, **kwargs):
@@ -112,7 +112,7 @@ class PreNorm(nn.Module):
         if self.norm_cross is not None:
             z = kwargs['z']
             normed_context = self.norm_cross(z)
-            kwargs.update(cross_val = normed_context)
+            kwargs.update(z = normed_context)
 
         return self.fn(x, **kwargs)
 
@@ -373,7 +373,7 @@ class TransformerSlotAttention(nn.Module):
 
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            x = cross_attn(slots, inputs) + slots # Cross-attention + Residual
+            x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
 
             ## Apply self-attention on input tokens but only before last depth layer
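
To make the intent of the two PreNorm hunks concrete, here is a minimal, self-contained sketch (not repository code) of how the patched wrapper behaves: the context tensor passed as `z` is layer-normalized and forwarded under the same `z` keyword, so the `cross_attn(slots, z = inputs)` call in `TransformerSlotAttention` reaches the wrapped attention module with a normalized context. The `ToyCrossAttention` module, the tensor shapes, and the `x = self.norm(x)` step are illustrative assumptions.

import torch
import torch.nn as nn

class PreNorm(nn.Module):
    # Sketch of the patched wrapper: normalize the query input, and if a
    # cross-attention context was given under 'z', normalize it and pass it
    # on under the same keyword.
    def __init__(self, dim, fn, cross_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None

    def forward(self, x, **kwargs):
        x = self.norm(x)  # assumed pre-normalization of the query input
        if self.norm_cross is not None:
            normed_context = self.norm_cross(kwargs['z'])
            kwargs.update(z=normed_context)  # same key the wrapped fn reads
        return self.fn(x, **kwargs)

class ToyCrossAttention(nn.Module):
    # Illustrative stand-in for the repository's cross-attention block:
    # queries come from x (the slots), keys/values from the context z.
    def __init__(self, dim, cross_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(cross_dim, 2 * dim)

    def forward(self, x, z=None):
        q = self.to_q(x)
        k, v = self.to_kv(z).chunk(2, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return attn @ v

slots = torch.randn(2, 6, 64)     # (batch, num_slots, slot_dim)
inputs = torch.randn(2, 100, 96)  # (batch, num_input_tokens, input_dim)
cross_attn = PreNorm(64, ToyCrossAttention(64, 96), cross_dim=96)
x = cross_attn(slots, z=inputs) + slots  # mirrors the patched call site
print(x.shape)  # torch.Size([2, 6, 64])

With the old `cross_val` key, the wrapped module would either fail on the unexpected keyword or, if it accepted extra kwargs, keep attending over the raw un-normalized context passed as `z`; forwarding the normalized tensor under `z` keeps it on the path the attention module actually reads.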