diff --git a/osrt/model.py b/osrt/model.py
index ebd7a171fda91fa085f68bff4f0d44311b33aad8..71511e635ca2dc8f9b8695bbab3a9bc2314cfa16 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
     
     def configure_optimizers(self) -> Any:
-        self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
-        return self.optimizer
+        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
+        return optimizer
     
     def one_step(self, image):
         x = self.encoder_cnn(image).movedim(1, -1)
@@ -166,10 +166,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
 
-        # Get and apply gradients.
-        self.optimizer.zero_grad()
-        loss_value.backward()
-        self.optimizer.step()
         self.log('train_mse', loss_value, on_epoch=True)
 
         return {'loss': loss_value}