diff --git a/osrt/layers.py b/osrt/layers.py
index be1cbf9bbd238a298bff3c377837ce137efd157f..e04e74d07b777be251662e26814e8c1234e612f3 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -261,3 +261,7 @@ class SlotAttention(nn.Module):
     def change_slots_number(self, num_slots):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
+
+
+class TransformerSlotAttention(nn.Module):
+    pass
\ No newline at end of file
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index b3fc70a6407bb83c1aa61e9bb4cc24b72a81776a..375829919093629588057f1bd384c1a103e7aae0 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -174,8 +174,8 @@ class TwoWayAttentionBlock(nn.Module):
         self.mlp = self.mlp.to(queries.device)
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
-        queries = self.norm3(queries)
         self.norm3 = self.norm3.to(queries.device)
+        queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe