From d74645d90679a6b22e57a40b47d468a93e9817e1 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 17:10:43 +0200
Subject: [PATCH] Set criterion

---
 osrt/model.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/osrt/model.py b/osrt/model.py
index 460a967..2f1864d 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -64,6 +64,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         self.num_slots = num_slots
         self.num_iterations = num_iterations
 
+        self.criterion = nn.MSELoss()
+
         self.encoder_cnn = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
             nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
@@ -152,7 +154,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
 
-    def training_step(self, batch, criterion):
+    def training_step(self, batch, batch_idx):
         """Perform a single training step."""
         input_image = torch.squeeze(batch.get('input_images'), dim=1)
         input_image = F.interpolate(input_image, size=128)
@@ -161,7 +163,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
         input_image = input_image.permute(0, 2, 3, 1)
-        loss_value = criterion(recon_combined, input_image)
+        loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
 
         # Get and apply gradients.
@@ -172,7 +174,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         return loss_value.item()
     
-    def validation_step(self, batch, criterion):
+    def validation_step(self, batch, batch_idx):
         """Perform a single eval step."""
         input_image = torch.squeeze(batch.get('input_images'), dim=1)
         input_image = F.interpolate(input_image, size=128)
@@ -181,7 +183,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
         input_image = input_image.permute(0, 2, 3, 1)
-        loss_value = criterion(recon_combined, input_image)
+        loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)
         self.log('val_mse', loss_value)
-- 
GitLab