From e6e1600bd35eddfb1cc15ecccf72839ccc0f5736 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 25 Jul 2023 15:54:23 +0200
Subject: [PATCH] Log input shapes and re-enable input normalization

---
 osrt/layers.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index d74ff57..6662e55 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -354,8 +354,10 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
+        print(f"Shape inputs first : {inputs.shape}")
 
-        #inputs = self.norm_input(inputs)
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after norm : {inputs.shape}")
         
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -371,9 +373,11 @@ class TransformerSlotAttention(nn.Module):
         enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1) 
 
+        inputs = self.norm_input(inputs)
+        print(f"Shape inputs after encoding : {inputs.shape}")
+
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
-            print(f"Shape inputs : {inputs}")
             x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
 
-- 
GitLab
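
For context, below is a minimal, self-contained sketch (outside the patch itself) of the forward flow this diff instruments. It is an illustration under assumptions, not the actual osrt/layers.py implementation: norm_input is assumed to be a LayerNorm over the raw feature dimension, cs_layers is assumed to hold (cross-attention, feed-forward) pairs applied with residual connections, and all dimensions, the nn.MultiheadAttention choice, and the class name SlotCrossAttentionSketch are hypothetical. The patch's second norm_input call after the positional-encoding concatenation is omitted here, because under these assumed dimensions the concatenated width would no longer match the LayerNorm's normalized shape.

import torch
import torch.nn as nn


class SlotCrossAttentionSketch(nn.Module):
    # Hypothetical sketch of the instrumented forward pass; dims are placeholders.
    def __init__(self, input_dim=64, pos_dim=16, slot_dim=80, num_slots=6,
                 num_tokens=32, depth=2):
        super().__init__()
        self.depth = depth
        self.norm_input = nn.LayerNorm(input_dim)
        # Learned positional encoding, broadcast over the batch at forward time.
        self.enc_pos = nn.Parameter(torch.randn(num_tokens, pos_dim))
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        # One (cross-attention, feed-forward) pair per depth step.
        self.cs_layers = nn.ModuleList([
            nn.ModuleList([
                nn.MultiheadAttention(slot_dim, num_heads=4,
                                      kdim=input_dim + pos_dim,
                                      vdim=input_dim + pos_dim,
                                      batch_first=True),
                nn.Sequential(nn.Linear(slot_dim, slot_dim), nn.GELU(),
                              nn.Linear(slot_dim, slot_dim)),
            ])
            for _ in range(depth)
        ])

    def forward(self, inputs):
        # inputs: [batch_size, num_tokens, input_dim]
        batch_size = inputs.shape[0]
        print(f"Shape inputs first : {inputs.shape}")

        inputs = self.norm_input(inputs)  # normalize raw features
        print(f"Shape inputs after norm : {inputs.shape}")

        slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)

        enc_pos = self.enc_pos.unsqueeze(0).expand(batch_size, -1, -1)
        inputs = torch.cat((inputs, enc_pos), dim=-1)  # append positional encoding
        print(f"Shape inputs after encoding : {inputs.shape}")

        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            x, _ = cross_attn(slots, inputs, inputs)  # slots attend to the inputs
            x = x + slots                             # residual
            slots = cross_ff(x) + x                   # feed-forward + residual
        return slots


if __name__ == "__main__":
    model = SlotCrossAttentionSketch()
    out = model(torch.randn(2, 32, 64))
    print(out.shape)  # expected: torch.Size([2, 6, 80])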