From f8daeceeb55fce919c537d75f852fc54239ced70 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 11:14:01 +0200
Subject: [PATCH] Fix cross-attention key/value input dim

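Pass kv_dim=self.input_dim to the cross-attention layers of
TransformerSlotAttention so that keys and values are projected from the
encoder features (input_dim), while the queries and the PreNorm keep
using the slot dimension (slot_dim). Without kv_dim, the Attention
module presumably derives its key/value projection from slot_dim, which
breaks whenever input_dim != slot_dim.

For reference, a minimal sketch of how kv_dim is typically consumed by
such an Attention module (this assumes the constructor signature quoted
in the inline comment above the changed line; the actual implementation
in osrt/layers.py may differ):

    import torch.nn as nn

    class Attention(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, dropout=0.,
                     selfatt=True, kv_dim=None):
            super().__init__()
            inner_dim = dim_head * heads
            self.heads = heads
            self.scale = dim_head ** -0.5
            if selfatt:
                # self-attention: q, k, v all projected from the same input (dim)
                self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
            else:
                # cross-attention: queries from dim, keys/values from kv_dim
                self.to_q = nn.Linear(dim, inner_dim, bias=False)
                self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)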
---
 osrt/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index 9eaab5c..cd3d2c1 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -319,7 +319,7 @@ class TransformerSlotAttention(nn.Module):
         for _ in range(depth):
             # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, selfatt=False)),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
 
-- 
GitLab