diff --git a/osrt/layers.py b/osrt/layers.py
index 9eaab5c63a443fff0ef5e2a938eb7c2cbed03b08..cd3d2c1d4a93d229d95207c8d5788cafbdce9892 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -319,7 +319,7 @@ class TransformerSlotAttention(nn.Module):
         for _ in range(depth):
             # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, selfatt=False)),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
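
The change forwards `kv_dim=self.input_dim` to the cross-attention layers built inside `TransformerSlotAttention`. Below is a minimal sketch (not the repository's exact implementation) of an `Attention` module matching the commented-out signature in the hunk, illustrating why `kv_dim` matters when `selfatt=False`: queries are projected from the slot dimension, while keys/values are projected from the encoder-feature dimension, so the key/value projection must be sized with `kv_dim`. The tensor sizes in the usage example are hypothetical.

```python
# Sketch of a cross-attention module with the signature
#   __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None)
# Assumption: this approximates the repo's Attention class; it is not copied from it.
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Queries always come from the query stream (e.g. slots of width `dim`).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        if selfatt:
            # Self-attention: keys/values share the query input, so kv_dim is unused.
            self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        else:
            # Cross-attention: keys/values come from a separate tensor whose feature
            # size is kv_dim. If kv_dim is not passed (the bug this diff fixes), the
            # projection is built with the wrong input width and the forward pass
            # fails with a shape mismatch.
            self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, z=None):
        # x: query tokens [B, S, dim]; z: key/value tokens [B, N, kv_dim] for
        # cross-attention. Falls back to self-attention when z is None.
        if z is None:
            z = x
        b, h = x.shape[0], self.heads

        q = self.to_q(x)
        k, v = self.to_kv(z).chunk(2, dim=-1)

        def split_heads(t):
            # [B, T, inner_dim] -> [B, heads, T, dim_head]
            return t.view(b, -1, h, t.shape[-1] // h).transpose(1, 2)

        q, k, v = map(split_heads, (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, x.shape[1], -1)
        return self.to_out(out)


if __name__ == "__main__":
    # Hypothetical sizes: slots of width 1536 attending over 768-dim encoder tokens.
    slot_dim, input_dim = 1536, 768
    cross = Attention(slot_dim, heads=1, dim_head=1536, selfatt=False, kv_dim=input_dim)
    slots = torch.randn(2, 6, slot_dim)
    feats = torch.randn(2, 100, input_dim)
    print(cross(slots, feats).shape)  # torch.Size([2, 6, 1536])
```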