From d25f228e781df65843527f258270ee56b8c7a874 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 25 Jul 2023 08:51:31 +0200
Subject: [PATCH] Update optimizer

---
 osrt/model.py | 4 ++--
 train_sa.py   | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/osrt/model.py b/osrt/model.py
index f5ed08a..d989406 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
     
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
-        return optimizer
+        self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
+        return self.optimizer
     
     def one_step(self, image):
         x = self.encoder_cnn(image).movedim(1, -1)
diff --git a/train_sa.py b/train_sa.py
index 269319c..a35c7c3 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -60,7 +60,6 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
 
-
     checkpoint_callback = ModelCheckpoint(
         save_top_k=10,
         monitor="val_psnr",
-- 
GitLab