Commit 552a761b authored by Alexandre Chapin

Apply prenorm on cross-attn

parent 29d94b3f
@@ -100,13 +100,19 @@ class RayEncoder(nn.Module):
 # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
 class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
+    def __init__(self, dim, fn, cross_dim=None):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.fn = fn
+        self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None

     def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
+        x = self.norm(x)
+        if self.norm_cross is not None:
+            z = kwargs['z']
+            normed_context = self.norm_cross(z)
+            kwargs.update(cross_val = normed_context)
+        return self.fn(x, **kwargs)


 class FeedForward(nn.Module):
@@ -334,7 +340,7 @@ class TransformerSlotAttention(nn.Module):
         for _ in range(depth):
             # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False)),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False), self.input_dim),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
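For context, here is a minimal runnable sketch of what the updated PreNorm does at call time. It copies the "+" lines above and pairs them with a toy stand-in for the repo's Attention module, which is not shown in this diff; the only assumption is that the wrapped module can read the pre-normalized context from the cross_val keyword.

import torch
import torch.nn as nn

# PreNorm as introduced by this commit (copied from the "+" lines above).
class PreNorm(nn.Module):
    def __init__(self, dim, fn, cross_dim=None):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        # Second LayerNorm for the cross-attention context, only when cross_dim is given.
        self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if self.norm_cross is not None:
            z = kwargs['z']
            normed_context = self.norm_cross(z)
            kwargs.update(cross_val=normed_context)
        return self.fn(x, **kwargs)

# Hypothetical stand-in for the repo's Attention module (real signature not in this diff).
# It accepts both the raw context `z` and the normalized context `cross_val`.
class ToyCrossAttention(nn.Module):
    def __init__(self, dim, kv_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(kv_dim, dim)

    def forward(self, x, z=None, cross_val=None):
        context = cross_val if cross_val is not None else z
        # Crude "attention": pool the projected context and add it to the queries.
        return self.to_q(x) + self.to_kv(context).mean(dim=1, keepdim=True)

slot_dim, input_dim = 64, 180
layer = PreNorm(slot_dim, ToyCrossAttention(slot_dim, kv_dim=input_dim), cross_dim=input_dim)

slots = torch.randn(2, 6, slot_dim)        # queries (slots)
features = torch.randn(2, 100, input_dim)  # cross-attention context
out = layer(slots, z=features)             # both slots and features get LayerNorm'd
print(out.shape)                           # torch.Size([2, 6, 64])

The wrapped module itself is unchanged by the commit: PreNorm now applies one LayerNorm to the queries and a second LayerNorm, sized to input_dim, to the cross-attention context, so both streams enter the attention block normalized (pre-norm on queries and keys/values alike).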