Skip to content
Snippets Groups Projects
Commit 50ab36b3 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix norm3

parent 569a8b38
No related branches found
No related tags found
No related merge requests found
......@@ -261,3 +261,7 @@ class SlotAttention(nn.Module):
def change_slots_number(self, num_slots):
self.num_slots = num_slots
self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
class TransformerSlotAttention(nn.Module):
pass
\ No newline at end of file
......@@ -174,8 +174,8 @@ class TwoWayAttentionBlock(nn.Module):
self.mlp = self.mlp.to(queries.device)
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
self.norm3 = self.norm3.to(queries.device)
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment