Commit e6e1600b authored by Alexandre Chapin

Get some logs

parent e2d5300d
@@ -354,8 +354,10 @@ class TransformerSlotAttention(nn.Module):
"""
batch_size, *axis = inputs.shape
device = inputs.device
print(f"Shape inputs first : {inputs.shape}")
#inputs = self.norm_input(inputs)
inputs = self.norm_input(inputs)
print(f"Shape inputs after norm : {inputs.shape}")
if self.randomize_initial_slots:
slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -371,9 +373,11 @@ class TransformerSlotAttention(nn.Module):
enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1)
inputs = self.norm_input(inputs)
print(f"Shape inputs after encoding : {inputs.shape}")
for i in range(self.depth):
cross_attn, cross_ff = self.cs_layers[i]
print(f"Shape inputs : {inputs}")
x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
slots = cross_ff(x) + x # Feed-forward + Residual
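
For readers skimming the diff, below is a minimal, self-contained sketch of the cross-attention / feed-forward slot-update loop this commit instruments with shape logs. Names such as cs_layers, norm_input, and initial_slots mirror the diff; the attention and feed-forward internals, the dimensions, and the use of fixed learned initial slots (rather than the randomize_initial_slots branch shown above) are assumptions for illustration, not the repository's actual implementation.

# Minimal sketch of a TransformerSlotAttention-style update loop.
# Only the structure visible in the diff is reproduced; internals are assumed.
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, slots, z):
        # Slots (queries) attend to the normalized input features z (keys/values).
        q = self.norm(slots)
        out, _ = self.attn(q, z, z)
        return out

class FeedForward(nn.Module):
    def __init__(self, dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, hidden),
            nn.GELU(), nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)

class SlotAttentionSketch(nn.Module):
    def __init__(self, dim=64, num_slots=8, depth=3):
        super().__init__()
        self.depth = depth
        self.norm_input = nn.LayerNorm(dim)
        # Learned initial slots, expanded per batch (fixed, not randomized here).
        self.initial_slots = nn.Parameter(torch.randn(num_slots, dim))
        self.cs_layers = nn.ModuleList(
            nn.ModuleList([CrossAttention(dim), FeedForward(dim)])
            for _ in range(depth)
        )

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        inputs = self.norm_input(inputs)
        slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            x = cross_attn(slots, z=inputs) + slots  # cross-attention + residual
            slots = cross_ff(x) + x                  # feed-forward + residual
        return slots

# Usage: a batch of 2 inputs with 196 feature tokens of width 64.
feats = torch.randn(2, 196, 64)
print(SlotAttentionSketch()(feats).shape)  # torch.Size([2, 8, 64])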