Skip to content
Snippets Groups Projects
Commit cf9cc951 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Loss shape test

parent d74645d9
No related branches found
No related tags found
No related merge requests found
...@@ -162,7 +162,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): ...@@ -162,7 +162,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
# Get the prediction of the model and compute the loss. # Get the prediction of the model and compute the loss.
preds = self.one_step(input_image) preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds recon_combined, recons, masks, slots, _ = preds
input_image = input_image.permute(0, 2, 3, 1) #input_image = input_image.permute(0, 2, 3, 1)
loss_value = self.criterion(recon_combined, input_image) loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
...@@ -182,7 +182,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): ...@@ -182,7 +182,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
# Get the prediction of the model and compute the loss. # Get the prediction of the model and compute the loss.
preds = self.one_step(input_image) preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds recon_combined, recons, masks, slots, _ = preds
input_image = input_image.permute(0, 2, 3, 1) #input_image = input_image.permute(0, 2, 3, 1)
loss_value = self.criterion(recon_combined, input_image) loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value) psnr = mse2psnr(loss_value)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment