diff --git a/osrt/layers.py b/osrt/layers.py index 19f9395e53cf9aa782fe77a5d7ef729a083f3874..fbec56074ce6d4ec7c5e611f9f9406a756da761e 100644 --- a/osrt/layers.py +++ b/osrt/layers.py @@ -244,7 +244,7 @@ class SlotAttention(nn.Module): k, v = self.to_k(inputs), self.to_v(inputs) - slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device = self.slots_mu.device) + slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_dim, device = self.slots_mu.device) # Multiple rounds of attention. for _ in range(self.iters):