From bd3fd53ff3299338f95932f8ce1ee95c978c005f Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:24:16 +0200
Subject: [PATCH] Change MSE loss

---
 osrt/layers.py | 1 -
 train_sa.py    | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index ef2d03e..0cfcff1 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -295,7 +295,6 @@ class SoftPositionEmbed(nn.Module):
     def forward(self, inputs):
         return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
 
-
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/train_sa.py b/train_sa.py
index d973ec9..09e0c1e 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -43,7 +43,7 @@ def eval_step(batch, model, device):
         preds = model(input_image)
         recon_combined, recons, masks, slots = preds
         input_image = input_image.permute(0, 2, 3, 1)
-        loss_value = nn.MSELoss(recon_combined, input_image)
+        loss_value = F.mse_loss(recon_combined, input_image)
         del recons, masks, slots # Unused.
         psnr = mse2psnr(loss_value)
-- 
GitLab