Skip to content
Snippets Groups Projects
Commit d25f228e authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Update optimizer

parent cf9cc951
No related branches found
No related tags found
No related merge requests found
...@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): ...@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
def configure_optimizers(self) -> Any: def configure_optimizers(self) -> Any:
optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
return optimizer return self.optimizer
def one_step(self, image): def one_step(self, image):
x = self.encoder_cnn(image).movedim(1, -1) x = self.encoder_cnn(image).movedim(1, -1)
......
...@@ -60,7 +60,6 @@ def main(): ...@@ -60,7 +60,6 @@ def main():
#### Create model #### Create model
model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg) model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
save_top_k=10, save_top_k=10,
monitor="val_psnr", monitor="val_psnr",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment