From d74645d90679a6b22e57a40b47d468a93e9817e1 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 17:10:43 +0200 Subject: [PATCH] Set criterion --- osrt/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index 460a967..2f1864d 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -64,6 +64,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): self.num_slots = num_slots self.num_iterations = num_iterations + self.criterion = nn.MSELoss() + self.encoder_cnn = nn.Sequential( nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), @@ -152,7 +154,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) - def training_step(self, batch, criterion): + def training_step(self, batch, batch_idx): """Perform a single training step.""" input_image = torch.squeeze(batch.get('input_images'), dim=1) input_image = F.interpolate(input_image, size=128) @@ -161,7 +163,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds input_image = input_image.permute(0, 2, 3, 1) - loss_value = criterion(recon_combined, input_image) + loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. # Get and apply gradients. @@ -172,7 +174,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return loss_value.item() - def validation_step(self, batch, criterion): + def validation_step(self, batch, batch_idx): """Perform a single eval step.""" input_image = torch.squeeze(batch.get('input_images'), dim=1) input_image = F.interpolate(input_image, size=128) @@ -181,7 +183,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds input_image = input_image.permute(0, 2, 3, 1) - loss_value = criterion(recon_combined, input_image) + loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. psnr = mse2psnr(loss_value) self.log('val_mse', loss_value) -- GitLab