From 50ab36b3696bbe52ae27b1d80d3d29e83c2f7a70 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 15:33:16 +0200
Subject: [PATCH] Fix norm3

---
 osrt/layers.py          | 4 ++++
 osrt/sam/transformer.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index be1cbf9..e04e74d 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -261,3 +261,7 @@ class SlotAttention(nn.Module):
     def change_slots_number(self, num_slots): 
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
+
+
+class TransformerSlotAttention(nn.Module):
+    pass
\ No newline at end of file
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index b3fc70a..3758299 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -174,8 +174,8 @@ class TwoWayAttentionBlock(nn.Module):
         self.mlp = self.mlp.to(queries.device)
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
-        queries = self.norm3(queries)
         self.norm3 = self.norm3.to(queries.device)
+        queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
-- 
GitLab