From bd3fd53ff3299338f95932f8ce1ee95c978c005f Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:24:16 +0200
Subject: [PATCH] Change MSE loss

---
 osrt/layers.py | 1 -
 train_sa.py    | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index ef2d03e..0cfcff1 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -295,7 +295,6 @@ class SoftPositionEmbed(nn.Module):
     def forward(self, inputs):
         return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) #  from [b, h, w, c] to [b, c, h, w]
     
-    
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/train_sa.py b/train_sa.py
index d973ec9..09e0c1e 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -43,7 +43,7 @@ def eval_step(batch, model, device):
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
     input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = nn.MSELoss(recon_combined, input_image)
+    loss_value = F.mse_loss(recon_combined, input_image)
     del recons, masks, slots  # Unused.
     psnr = mse2psnr(loss_value)
 
-- 
GitLab