Skip to content
Snippets Groups Projects
Commit 78bc3df2 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Add logs'

parent f02c21f9
No related branches found
No related tags found
No related merge requests found
...@@ -103,8 +103,8 @@ class RayEncoder(nn.Module): ...@@ -103,8 +103,8 @@ class RayEncoder(nn.Module):
class PreNorm(nn.Module): class PreNorm(nn.Module):
def __init__(self, dim, fn, cross_dim=None): def __init__(self, dim, fn, cross_dim=None):
super().__init__() super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn self.fn = fn
self.norm = nn.LayerNorm(dim)
self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
...@@ -112,7 +112,7 @@ class PreNorm(nn.Module): ...@@ -112,7 +112,7 @@ class PreNorm(nn.Module):
if self.norm_cross is not None: if self.norm_cross is not None:
z = kwargs['z'] z = kwargs['z']
normed_context = self.norm_cross(z) normed_context = self.norm_cross(z)
kwargs.update(cross_val = normed_context) kwargs.update(z = normed_context)
return self.fn(x, **kwargs) return self.fn(x, **kwargs)
...@@ -373,7 +373,8 @@ class TransformerSlotAttention(nn.Module): ...@@ -373,7 +373,8 @@ class TransformerSlotAttention(nn.Module):
for i in range(self.depth): for i in range(self.depth):
cross_attn, cross_ff = self.cs_layers[i] cross_attn, cross_ff = self.cs_layers[i]
x = cross_attn(slots, inputs) + slots # Cross-attention + Residual print(f"Shape slots {slots.shape} an inputs shape {inputs.shape}")
x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
slots = cross_ff(x) + x # Feed-forward + Residual slots = cross_ff(x) + x # Feed-forward + Residual
## Apply self-attention on input tokens but only before last depth layer ## Apply self-attention on input tokens but only before last depth layer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment