From d25f228e781df65843527f258270ee56b8c7a874 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 25 Jul 2023 08:51:31 +0200
Subject: [PATCH] Update optimizer

---
 osrt/model.py | 4 ++--
 train_sa.py   | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/osrt/model.py b/osrt/model.py
index f5ed08a..d989406 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
 
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
-        return optimizer
+        self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
+        return self.optimizer
 
     def one_step(self, image):
         x = self.encoder_cnn(image).movedim(1, -1)
diff --git a/train_sa.py b/train_sa.py
index 269319c..a35c7c3 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -60,7 +60,6 @@ def main():
 
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
-
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="val_psnr",
--
GitLab
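
Editor's note: the change keeps a reference to the optimizer on the module
(self.optimizer) instead of a local variable, presumably so other Lightning
hooks can reach it after configure_optimizers runs (for example, to inspect or
log the current learning rate). Below is a minimal sketch of that pattern; the
ToyLitModule class, its nn.Linear layer, and the learning-rate logging are
hypothetical illustrations, not code from this repository.

from torch import nn, optim
import pytorch_lightning as pl


class ToyLitModule(pl.LightningModule):
    """Toy module illustrating the pattern used in the patch above."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # Because configure_optimizers stored the optimizer on self,
        # it can be inspected here, e.g. to log the current learning rate.
        self.log("lr", self.optimizer.param_groups[0]["lr"])
        return loss

    def configure_optimizers(self):
        # Same pattern as the patch: keep a reference for later access.
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
        return self.optimizer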