From c65c403ef24be96fb202f4cc3e4c3b2a7479998c Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 16:20:00 +0200
Subject: [PATCH] Fix slot-attention: always resample initial slots from the learned distribution

---
 osrt/layers.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index deaf9d8..19f9395 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -244,8 +244,7 @@ class SlotAttention(nn.Module):
 
         k, v = self.to_k(inputs), self.to_v(inputs)
 
-        if slots is None:
-            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device = self.slots_mu.device)
+        slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device = self.slots_mu.device)
 
         # Multiple rounds of attention.
         for _ in range(self.iters):
-- 
GitLab