From f8daeceeb55fce919c537d75f852fc54239ced70 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 11:14:01 +0200 Subject: [PATCH] Change cross attention input dim --- osrt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/osrt/layers.py b/osrt/layers.py index 9eaab5c..cd3d2c1 100644 --- a/osrt/layers.py +++ b/osrt/layers.py @@ -319,7 +319,7 @@ class TransformerSlotAttention(nn.Module): for _ in range(depth): # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None): self.cs_layers.append(nn.ModuleList([ - PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, selfatt=False)), + PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)), PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim)) ])) -- GitLab