Skip to content
Snippets Groups Projects
Commit c50a1ab9 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Delete unused code for training model

parent 362130f8
No related branches found
No related tags found
No related merge requests found
...@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): ...@@ -131,8 +131,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
def configure_optimizers(self) -> Any: def configure_optimizers(self) -> Any:
self.optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
return self.optimizer return optimizer
def one_step(self, image): def one_step(self, image):
x = self.encoder_cnn(image).movedim(1, -1) x = self.encoder_cnn(image).movedim(1, -1)
...@@ -166,10 +166,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): ...@@ -166,10 +166,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
loss_value = self.criterion(recon_combined, input_image) loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
# Get and apply gradients.
self.optimizer.zero_grad()
loss_value.backward()
self.optimizer.step()
self.log('train_mse', loss_value, on_epoch=True) self.log('train_mse', loss_value, on_epoch=True)
return {'loss': loss_value} return {'loss': loss_value}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment