From 735336c160dfc5e093f9d4634715659d7c55b378 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 16:14:13 +0200
Subject: [PATCH] Fix optimizer construction (call self.parameters()) and drop undefined slots kwarg in slot_attention call

---
 osrt/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/osrt/model.py b/osrt/model.py
index 29ac38e..955d198 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -129,7 +129,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
     
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08)
+        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
         return optimizer
     
     def one_step(self, image):
@@ -137,7 +137,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         x = self.encoder_pos(x)
         x = self.mlp(self.layer_norm(x))
         
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots)
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
         x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
         x = self.decoder_pos(x)
         x = self.decoder_cnn(x.movedim(-1, 1))
-- 
GitLab