From e6e1600bd35eddfb1cc15ecccf72839ccc0f5736 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 25 Jul 2023 15:54:23 +0200
Subject: [PATCH] Get some logs

---
 osrt/layers.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index d74ff57..6662e55 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -354,8 +354,10 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
+        print(f"Shape inputs first : {inputs.shape}")

-        #inputs = self.norm_input(inputs)
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after norm : {inputs.shape}")

         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -371,9 +373,11 @@ class TransformerSlotAttention(nn.Module):
             enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
             inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1)

+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after encoding : {inputs.shape}")
+
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            print(f"Shape inputs : {inputs}")

             x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
-- 
GitLab
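
Note (not part of the patch): below is a minimal, self-contained sketch of the shape-logging pattern the diff introduces in TransformerSlotAttention.forward. The attribute names (norm_input, initial_slots, cs_layers, depth) mirror the diff; the dimensions, the MultiheadAttention/feed-forward stand-ins, and the omission of the positional-encoding branch are assumptions made for illustration, not the repository's actual implementation.

    import torch
    import torch.nn as nn

    class ShapeLoggingSlotAttention(nn.Module):
        """Stand-in module reproducing the print-based shape logging from the patch."""

        def __init__(self, input_dim=64, num_slots=6, slot_dim=64, depth=2):
            super().__init__()
            self.depth = depth
            self.norm_input = nn.LayerNorm(input_dim)
            self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
            # One (cross-attention, feed-forward) pair per depth level.
            self.cs_layers = nn.ModuleList([
                nn.ModuleList([
                    nn.MultiheadAttention(slot_dim, num_heads=4, batch_first=True),
                    nn.Sequential(nn.Linear(slot_dim, slot_dim), nn.GELU(),
                                  nn.Linear(slot_dim, slot_dim)),
                ]) for _ in range(depth)
            ])

        def forward(self, inputs):
            batch_size = inputs.shape[0]
            print(f"Shape inputs first : {inputs.shape}")

            inputs = self.norm_input(inputs)
            print(f"Shape inputs after norm : {inputs.shape}")

            # Broadcast the learned slot initialisation over the batch.
            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)

            for i in range(self.depth):
                cross_attn, cross_ff = self.cs_layers[i]
                # Cross-attention + residual, then feed-forward + residual,
                # mirroring the loop the diff instruments.
                attn_out, _ = cross_attn(slots, inputs, inputs)
                x = attn_out + slots
                slots = cross_ff(x) + x
            return slots

    if __name__ == "__main__":
        model = ShapeLoggingSlotAttention()
        out = model(torch.randn(2, 100, 64))  # (batch, tokens, features)
        print(f"Shape slots : {out.shape}")   # expected: torch.Size([2, 6, 64])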