Skip to content
Snippets Groups Projects
Commit d74645d9 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Set criterion

parent 2802499a
No related branches found
No related tags found
No related merge requests found
......@@ -64,6 +64,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
self.num_slots = num_slots
self.num_iterations = num_iterations
self.criterion = nn.MSELoss()
self.encoder_cnn = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
......@@ -152,7 +154,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
def training_step(self, batch, criterion):
def training_step(self, batch, batch_idx):
"""Perform a single training step."""
input_image = torch.squeeze(batch.get('input_images'), dim=1)
input_image = F.interpolate(input_image, size=128)
......@@ -161,7 +163,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds
input_image = input_image.permute(0, 2, 3, 1)
loss_value = criterion(recon_combined, input_image)
loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused.
# Get and apply gradients.
......@@ -172,7 +174,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return loss_value.item()
def validation_step(self, batch, criterion):
def validation_step(self, batch, batch_idx):
"""Perform a single eval step."""
input_image = torch.squeeze(batch.get('input_images'), dim=1)
input_image = F.interpolate(input_image, size=128)
......@@ -181,7 +183,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds
input_image = input_image.permute(0, 2, 3, 1)
loss_value = criterion(recon_combined, input_image)
loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value)
self.log('val_mse', loss_value)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment