Commit f8daecee authored by Alexandre Chapin

Change cross attention input dim

parent f367b047
@@ -319,7 +319,7 @@ class TransformerSlotAttention(nn.Module):
         for _ in range(depth):
             # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, selfatt=False)),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
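
The added kv_dim=self.input_dim argument tells the cross-attention layer that its keys and values are projected from the encoder features (width input_dim) rather than from the slots (width slot_dim), so the key/value projection matches the actual feature width it receives. Below is a minimal sketch of how such a PreNorm/Attention pair could look; it follows the commented __init__ signature in the diff, but the body is an assumption (a standard Perceiver-style implementation), not the repository's exact code, and the dimensions in the usage example are hypothetical.

# Minimal sketch (assumed, Perceiver-style) of PreNorm + Attention with a kv_dim
# argument; parameter names follow the commented __init__ signature in the diff.
import torch
import torch.nn as nn

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Normalize the queries (slots) before the wrapped attention/feed-forward.
        return self.fn(self.norm(x), **kwargs)

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        kv_dim = dim if (selfatt or kv_dim is None) else kv_dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        # Queries come from the slot features (dim); keys/values come from the
        # context features (kv_dim). Without kv_dim, to_kv would expect inputs of
        # width `dim` and fail when the encoder features have a different width.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, z=None):
        context = x if z is None else z            # cross-attend to z when given
        b, n, _ = x.shape
        h = self.heads
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # Split heads: (b, seq, h*d) -> (b, h, seq, d).
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# Illustrative shapes (hypothetical dims): slots of width 128 attend to encoder features of width 180.
slots = torch.randn(2, 6, 128)
feats = torch.randn(2, 1024, 180)
cross = PreNorm(128, Attention(128, heads=4, dim_head=64, selfatt=False, kv_dim=180))
out = cross(slots, z=feats)                        # -> torch.Size([2, 6, 128])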