From cf9cc95195e639c9dba732e4d494af5c87fd7e3d Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 17:15:21 +0200 Subject: [PATCH] Loss shape test --- osrt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index 2f1864d..f5ed08a 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -162,7 +162,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): # Get the prediction of the model and compute the loss. preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds - input_image = input_image.permute(0, 2, 3, 1) + #input_image = input_image.permute(0, 2, 3, 1) loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. @@ -182,7 +182,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): # Get the prediction of the model and compute the loss. preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds - input_image = input_image.permute(0, 2, 3, 1) + #input_image = input_image.permute(0, 2, 3, 1) loss_value = self.criterion(recon_combined, input_image) del recons, masks, slots # Unused. psnr = mse2psnr(loss_value) -- GitLab