diff --git a/osrt/model.py b/osrt/model.py
index a653271b4437b557a97cbe34a25d11fa4ecb3cd1..460a967d95b7f28bf1f57b38f83e08a307be9ae3 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -159,7 +159,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
-        recon_combined, recons, masks, slots = preds
+        recon_combined, recons, masks, slots, _ = preds
         input_image = input_image.permute(0, 2, 3, 1)
         loss_value = criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
@@ -179,7 +179,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
-        recon_combined, recons, masks, slots = preds
+        recon_combined, recons, masks, slots, _ = preds
         input_image = input_image.permute(0, 2, 3, 1)
         loss_value = criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.