diff --git a/osrt/model.py b/osrt/model.py index 0db162a5165b5d1f400d8bf5e2f4c99e50190f5a..0102d6153d02118b3866e8bb6fd4adb8b1d0dcb9 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -125,7 +125,7 @@ class SlotAttentionAutoEncoder(nn.Module): input_dim=64, slot_dim=64, hidden_dim=128, - iters=self.num_iterations) + depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model def forward(self, image): # `image` has shape: [batch_size, num_channels, width, height].