Skip to content
Snippets Groups Projects
Commit bd3fd53f authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Change MSE loss

parent 6d55e14b
No related branches found
No related tags found
No related merge requests found
...@@ -295,7 +295,6 @@ class SoftPositionEmbed(nn.Module): ...@@ -295,7 +295,6 @@ class SoftPositionEmbed(nn.Module):
def forward(self, inputs): def forward(self, inputs):
return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
class TransformerSlotAttention(nn.Module): class TransformerSlotAttention(nn.Module):
""" """
......
...@@ -43,7 +43,7 @@ def eval_step(batch, model, device): ...@@ -43,7 +43,7 @@ def eval_step(batch, model, device):
preds = model(input_image) preds = model(input_image)
recon_combined, recons, masks, slots = preds recon_combined, recons, masks, slots = preds
input_image = input_image.permute(0, 2, 3, 1) input_image = input_image.permute(0, 2, 3, 1)
loss_value = nn.MSELoss(recon_combined, input_image) loss_value = F.mse_loss(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value) psnr = mse2psnr(loss_value)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment