Commit f367b047 authored by Alexandre Chapin
Fix param SlotAttAutoEncod

parent 20e7b9e9
@@ -125,7 +125,7 @@ class SlotAttentionAutoEncoder(nn.Module):
             input_dim=64,
             slot_dim=64,
             hidden_dim=128,
-            iters=self.num_iterations)
+            depth=self.num_iterations)  # in a way, the depth of the transformer corresponds to the number of iterations in the original model

     def forward(self, image):
         # `image` has shape: [batch_size, num_channels, width, height].
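For context, here is a minimal sketch of the correspondence the in-diff comment describes: the original Slot Attention module applies one shared update step for `iters` iterations, whereas a transformer-style replacement stacks `depth` distinct cross-attention layers, so `depth` plays the role that `iters` played. The class and parameter names below are illustrative assumptions, not this repository's actual API, and the sketch simplifies the original model (it omits Slot Attention's slot-normalized softmax and GRU update).

    # Hypothetical sketch only; names are assumptions, not the repo's code.
    import torch
    import torch.nn as nn

    class TransformerSlotRefiner(nn.Module):
        def __init__(self, num_slots=8, input_dim=64, slot_dim=64,
                     hidden_dim=128, depth=3):
            super().__init__()
            # Learned initial slots, shared across the batch.
            self.slots = nn.Parameter(torch.randn(1, num_slots, slot_dim) * 0.02)
            self.proj = nn.Linear(input_dim, slot_dim)
            # One cross-attention + MLP block per refinement step:
            # `depth` here stands in for `iters` in the original loop.
            self.layers = nn.ModuleList([
                nn.ModuleDict({
                    "attn": nn.MultiheadAttention(slot_dim, num_heads=1,
                                                  batch_first=True),
                    "norm1": nn.LayerNorm(slot_dim),
                    "mlp": nn.Sequential(
                        nn.Linear(slot_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim, slot_dim),
                    ),
                    "norm2": nn.LayerNorm(slot_dim),
                })
                for _ in range(depth)
            ])

        def forward(self, inputs):
            # inputs: [batch_size, num_inputs, input_dim],
            # e.g. flattened CNN feature maps.
            kv = self.proj(inputs)
            slots = self.slots.expand(inputs.shape[0], -1, -1)
            for layer in self.layers:
                # Slots query the inputs, then are refined by an MLP.
                attended, _ = layer["attn"](slots, kv, kv)
                slots = layer["norm1"](slots + attended)
                slots = layer["norm2"](slots + layer["mlp"](slots))
            return slots  # [batch_size, num_slots, slot_dim]

    refiner = TransformerSlotRefiner(depth=3)  # depth=3 ~ 3 slot-attention iterations
    feats = torch.randn(2, 1024, 64)
    print(refiner(feats).shape)  # torch.Size([2, 8, 64])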