From c50a1ab916846ff2485b7ce7f4643a4d2b847f69 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Tue, 25 Jul 2023 10:42:45 +0200 Subject: [PATCH] Delete unused code for training model --- osrt/model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index ebd7a17..71511e6 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) def configure_optimizers(self) -> Any: - self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) - return self.optimizer + optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) + return optimizer def one_step(self, image): x = self.encoder_cnn(image).movedim(1, -1) @@ -166,10 +166,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. - # Get and apply gradients. - self.optimizer.zero_grad() - loss_value.backward() - self.optimizer.step() self.log('train_mse', loss_value, on_epoch=True) return {'loss': loss_value} -- GitLab