diff --git a/osrt/model.py b/osrt/model.py index ebd7a171fda91fa085f68bff4f0d44311b33aad8..71511e635ca2dc8f9b8695bbab3a9bc2314cfa16 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) def configure_optimizers(self) -> Any: - self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) - return self.optimizer + optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) + return optimizer def one_step(self, image): x = self.encoder_cnn(image).movedim(1, -1) @@ -166,10 +166,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. - # Get and apply gradients. - self.optimizer.zero_grad() - loss_value.backward() - self.optimizer.step() self.log('train_mse', loss_value, on_epoch=True) return {'loss': loss_value}