diff --git a/osrt/layers.py b/osrt/layers.py
index d74ff57f8b4fd7bc7da7ca96666ba60700c26b33..6662e554a5fb2d151ea60c7cd0cf830266af2b74 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -354,8 +354,10 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
+        print(f"Shape inputs first : {inputs.shape}")
 
-        #inputs = self.norm_input(inputs)
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after norm : {inputs.shape}")
 
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -371,9 +373,11 @@ class TransformerSlotAttention(nn.Module):
             enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
             inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1)
 
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after encoding : {inputs.shape}")
+
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            print(f"Shape inputs : {inputs}")
 
             x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual