diff --git a/osrt/layers.py b/osrt/layers.py
index fbec56074ce6d4ec7c5e611f9f9406a756da761e..aadbdf50c6356436ab672f9f9c4f756ab1cbb445 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -103,8 +103,8 @@ class RayEncoder(nn.Module):
 class PreNorm(nn.Module):
     def __init__(self, dim, fn, cross_dim=None):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
         self.fn = fn
+        self.norm = nn.LayerNorm(dim)
         self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None
 
     def forward(self, x, **kwargs):
@@ -112,7 +112,7 @@ class PreNorm(nn.Module):
         if self.norm_cross is not None:
             z = kwargs['z']
             normed_context = self.norm_cross(z)
-            kwargs.update(cross_val = normed_context)
+            kwargs.update(z = normed_context)
 
         return self.fn(x, **kwargs)
 
@@ -373,7 +373,7 @@ class TransformerSlotAttention(nn.Module):
 
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            x = cross_attn(slots, inputs) + slots # Cross-attention + Residual
+            x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
 
             ## Apply self-attention on input tokens but only before last depth layer
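
To make the intent of the two PreNorm hunks concrete, here is a minimal, self-contained sketch (not repository code) of how the patched wrapper behaves: the context tensor passed as `z` is layer-normalized and forwarded under the same `z` keyword, so the `cross_attn(slots, z = inputs)` call in `TransformerSlotAttention` reaches the wrapped attention module with a normalized context. The `ToyCrossAttention` module, the tensor shapes, and the `x = self.norm(x)` step are illustrative assumptions.

import torch
import torch.nn as nn

class PreNorm(nn.Module):
    # Sketch of the patched wrapper: normalize the query input, and if a
    # cross-attention context was given under 'z', normalize it and pass it
    # on under the same keyword.
    def __init__(self, dim, fn, cross_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None

    def forward(self, x, **kwargs):
        x = self.norm(x)  # assumed pre-normalization of the query input
        if self.norm_cross is not None:
            normed_context = self.norm_cross(kwargs['z'])
            kwargs.update(z=normed_context)  # same key the wrapped fn reads
        return self.fn(x, **kwargs)

class ToyCrossAttention(nn.Module):
    # Illustrative stand-in for the repository's cross-attention block:
    # queries come from x (the slots), keys/values from the context z.
    def __init__(self, dim, cross_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(cross_dim, 2 * dim)

    def forward(self, x, z=None):
        q = self.to_q(x)
        k, v = self.to_kv(z).chunk(2, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return attn @ v

slots = torch.randn(2, 6, 64)     # (batch, num_slots, slot_dim)
inputs = torch.randn(2, 100, 96)  # (batch, num_input_tokens, input_dim)
cross_attn = PreNorm(64, ToyCrossAttention(64, 96), cross_dim=96)
x = cross_attn(slots, z=inputs) + slots  # mirrors the patched call site
print(x.shape)  # torch.Size([2, 6, 64])

With the old `cross_val` key, the wrapped module would either fail on the unexpected keyword or, if it accepted extra kwargs, keep attending over the raw un-normalized context passed as `z`; forwarding the normalized tensor under `z` keeps it on the path the attention module actually reads.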