diff --git a/osrt/model.py b/osrt/model.py index d98940686bc17d7a3dd24cc62708336feae0a011..bc816dadf4f180b8570d3a211eb0a64410f4c371 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -172,7 +172,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): self.optimizer.step() self.log('train_mse', loss_value, on_epoch=True) - return loss_value.item() + return {'train_mse': loss_value.item()} def validation_step(self, batch, batch_idx): """Perform a single eval step.""" @@ -189,5 +189,5 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): self.log('val_mse', loss_value) self.log('val_psnr', psnr) - return loss_value.item(), psnr.item() + return {'val_mse': loss_value.item(), 'val_psnr': psnr.item()}