diff --git a/osrt/model.py b/osrt/model.py
index 2f1864d568aa045bf9786b725f1375862636feaf..f5ed08ac75e6257e78427373ba6bac1d3359ef13 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -162,7 +162,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        input_image = input_image.permute(0, 2, 3, 1)
+        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
 
@@ -182,7 +182,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        input_image = input_image.permute(0, 2, 3, 1)
+        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)