Skip to content
Snippets Groups Projects
Commit 1915d2fe authored by Alexandre Chapin :race_car:
Browse files

Fix loss issues

parent bd3fd53f
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader ...@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
import torch.nn.functional as F import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
def train_step(batch, model, optimizer, device): def train_step(batch, model, optimizer, device, criterion):
"""Perform a single training step.""" """Perform a single training step."""
input_image = torch.squeeze(batch.get('input_images').to(device), dim=1) input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
input_image = F.interpolate(input_image, size=128) input_image = F.interpolate(input_image, size=128)
...@@ -24,7 +24,7 @@ def train_step(batch, model, optimizer, device): ...@@ -24,7 +24,7 @@ def train_step(batch, model, optimizer, device):
preds = model(input_image) preds = model(input_image)
recon_combined, recons, masks, slots = preds recon_combined, recons, masks, slots = preds
input_image = input_image.permute(0, 2, 3, 1) input_image = input_image.permute(0, 2, 3, 1)
loss_value = nn.MSELoss(recon_combined, input_image) loss_value = criterion(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
# Get and apply gradients. # Get and apply gradients.
...@@ -34,7 +34,7 @@ def train_step(batch, model, optimizer, device): ...@@ -34,7 +34,7 @@ def train_step(batch, model, optimizer, device):
return loss_value.item() return loss_value.item()
def eval_step(batch, model, device): def eval_step(batch, model, device, criterion):
"""Perform a single eval step.""" """Perform a single eval step."""
input_image = torch.squeeze(batch.get('input_images').to(device), dim=1) input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
input_image = F.interpolate(input_image, size=128) input_image = F.interpolate(input_image, size=128)
...@@ -43,7 +43,7 @@ def eval_step(batch, model, device): ...@@ -43,7 +43,7 @@ def eval_step(batch, model, device):
preds = model(input_image) preds = model(input_image)
recon_combined, recons, masks, slots = preds recon_combined, recons, masks, slots = preds
input_image = input_image.permute(0, 2, 3, 1) input_image = input_image.permute(0, 2, 3, 1)
loss_value = F.mse_loss(recon_combined, input_image) loss_value = criterion(recon_combined, input_image)
del recons, masks, slots # Unused. del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value) psnr = mse2psnr(loss_value)
...@@ -76,6 +76,7 @@ def main(): ...@@ -76,6 +76,7 @@ def main():
decay_rate = cfg["training"]["decay_rate"] decay_rate = cfg["training"]["decay_rate"]
decay_steps = cfg["training"]["decay_it"] decay_steps = cfg["training"]["decay_it"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.MSELoss()
resolution = (128, 128) resolution = (128, 128)
...@@ -148,7 +149,7 @@ def main(): ...@@ -148,7 +149,7 @@ def main():
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate param_group['lr'] = learning_rate
total_loss += train_step(batch, model, optimizer, device) total_loss += train_step(batch, model, optimizer, device, criterion)
global_step += 1 global_step += 1
total_loss /= len(train_loader) total_loss /= len(train_loader)
...@@ -176,7 +177,7 @@ def main(): ...@@ -176,7 +177,7 @@ def main():
val_psnr = 0 val_psnr = 0
model.eval() model.eval()
for batch in tqdm(val_loader): for batch in tqdm(val_loader):
mse, psnr = eval_step(batch, model, device) mse, psnr = eval_step(batch, model, device, criterion)
val_loss += mse val_loss += mse
val_psnr += psnr val_psnr += psnr
val_loss /= len(val_loader) val_loss /= len(val_loader)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment