From 735336c160dfc5e093f9d4634715659d7c55b378 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 16:14:13 +0200 Subject: [PATCH] Fix issue with model --- osrt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index 29ac38e..955d198 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -129,7 +129,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) def configure_optimizers(self) -> Any: - optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08) + optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) return optimizer def one_step(self, image): @@ -137,7 +137,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.encoder_pos(x) x = self.mlp(self.layer_norm(x)) - slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots) + slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2)) x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1) x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1)) -- GitLab