diff --git a/osrt/model.py b/osrt/model.py index 29ac38ef8d4c26ea05a0df2e3b3530d88ba9d0ff..955d19885ec6642630a668d6914fd16fb68fa4aa 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -129,7 +129,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) def configure_optimizers(self) -> Any: - optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08) + optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) return optimizer def one_step(self, image): @@ -137,7 +137,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.encoder_pos(x) x = self.mlp(self.layer_norm(x)) - slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots) + slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2)) x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1) x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1))