Commit c65c403e authored by Alexandre Chapin

Fix slot-attention

parent 735336c1
@@ -244,8 +244,7 @@ class SlotAttention(nn.Module):
         k, v = self.to_k(inputs), self.to_v(inputs)
-        slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device=self.slots_mu.device)
+        if slots is None: slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device=self.slots_mu.device)
         # Multiple rounds of attention.
         for _ in range(self.iters):
...
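
For context: before this commit, the forward pass resampled slots from the learned Gaussian unconditionally, discarding any slots the caller passed in; the fix makes the sampling conditional so externally supplied slots (e.g. carried over from a previous call) are reused. Below is a minimal sketch of just that initialization path. The parameter shapes and the surrounding module are assumptions for illustration; the real class also has the to_k/to_v projections and the attention loop shown in the diff.

import torch
import torch.nn as nn

class SlotInitSketch(nn.Module):
    """Hypothetical reduction of SlotAttention to its slot-initialization step."""

    def __init__(self, num_slots=7, slot_size=64):
        super().__init__()
        self.num_slots = num_slots
        self.slot_size = slot_size
        # Learned mean and log-std of the Gaussian that fresh slots are sampled from.
        self.slots_mu = nn.Parameter(torch.zeros(1, 1, slot_size))
        self.slots_log_sigma = nn.Parameter(torch.zeros(1, 1, slot_size))

    def forward(self, inputs, slots=None):
        # After the fix: sample fresh slots only when the caller did not supply
        # them, broadcasting (1, 1, slot_size) parameters over the noise tensor
        # of shape (batch, num_slots, slot_size).
        if slots is None:
            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(
                len(inputs), self.num_slots, self.slot_size,
                device=self.slots_mu.device)
        return slots

With this change a caller can pass the slots returned from one call back into the next (useful, for instance, when tracking objects across video frames), whereas previously the passed-in slots were silently overwritten by a fresh sample.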