diff --git a/osrt/layers.py b/osrt/layers.py
index 9eaab5c63a443fff0ef5e2a938eb7c2cbed03b08..cd3d2c1d4a93d229d95207c8d5788cafbdce9892 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -319,7 +319,7 @@ class TransformerSlotAttention(nn.Module):
         for _ in range(depth):
             # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, selfatt=False)),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
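
The change above passes `kv_dim=self.input_dim` to the cross-attention `PreNorm(Attention(...))` layer, so the key/value projections are sized to the input feature width instead of implicitly assuming it equals the slot width. As a minimal sketch of why that matters (this is not the repository's actual `Attention` class; `CrossAttentionSketch` and the dimensions in the usage lines are hypothetical, assuming the common pattern where `selfatt=False` routes keys/values through a separate context tensor):

```python
import torch
import torch.nn as nn

class CrossAttentionSketch(nn.Module):
    """Sketch of an attention layer whose keys/values can come from a
    context tensor with a different feature width than the queries."""

    def __init__(self, dim, heads=8, dim_head=64, selfatt=True, kv_dim=None):
        super().__init__()
        inner = heads * dim_head
        self.heads, self.scale = heads, dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner, bias=False)
        if selfatt:
            # Self-attention: keys/values share the query width.
            self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        else:
            # Cross-attention: without kv_dim this projection would expect
            # dim-sized context and raise a shape mismatch whenever the
            # inputs are wider or narrower than the slots.
            self.to_kv = nn.Linear(kv_dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x, z=None):
        z = x if z is None else z              # z: context (keys/values)
        q = self.to_q(x)
        k, v = self.to_kv(z).chunk(2, dim=-1)
        b, n, _, h = *q.shape, self.heads
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2)
                   for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# Hypothetical widths, for illustration only:
layer = CrossAttentionSketch(dim=1536, heads=1, dim_head=1536,
                             selfatt=False, kv_dim=768)
slots = torch.randn(2, 6, 1536)      # queries: slots
inputs = torch.randn(2, 2048, 768)   # keys/values: encoder features
out = layer(slots, inputs)           # -> (2, 6, 1536)
```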