From 1915d2fe8d95b6fb641b45957f44f73449193b50 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 10:26:48 +0200 Subject: [PATCH] Fix loss issues --- train_sa.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train_sa.py b/train_sa.py index 09e0c1e..14c0972 100644 --- a/train_sa.py +++ b/train_sa.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader import torch.nn.functional as F from tqdm import tqdm -def train_step(batch, model, optimizer, device): +def train_step(batch, model, optimizer, device, criterion): """Perform a single training step.""" input_image = torch.squeeze(batch.get('input_images').to(device), dim=1) input_image = F.interpolate(input_image, size=128) @@ -24,7 +24,7 @@ def train_step(batch, model, optimizer, device): preds = model(input_image) recon_combined, recons, masks, slots = preds input_image = input_image.permute(0, 2, 3, 1) - loss_value = nn.MSELoss(recon_combined, input_image) + loss_value = criterion(recon_combined, input_image) del recons, masks, slots # Unused. # Get and apply gradients. @@ -34,7 +34,7 @@ def train_step(batch, model, optimizer, device): return loss_value.item() -def eval_step(batch, model, device): +def eval_step(batch, model, device, criterion): """Perform a single eval step.""" input_image = torch.squeeze(batch.get('input_images').to(device), dim=1) input_image = F.interpolate(input_image, size=128) @@ -43,7 +43,7 @@ def eval_step(batch, model, device): preds = model(input_image) recon_combined, recons, masks, slots = preds input_image = input_image.permute(0, 2, 3, 1) - loss_value = F.mse_loss(recon_combined, input_image) + loss_value = criterion(recon_combined, input_image) del recons, masks, slots # Unused. psnr = mse2psnr(loss_value) @@ -76,6 +76,7 @@ def main(): decay_rate = cfg["training"]["decay_rate"] decay_steps = cfg["training"]["decay_it"] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + criterion = nn.MSELoss() resolution = (128, 128) @@ -148,7 +149,7 @@ def main(): for param_group in optimizer.param_groups: param_group['lr'] = learning_rate - total_loss += train_step(batch, model, optimizer, device) + total_loss += train_step(batch, model, optimizer, device, criterion) global_step += 1 total_loss /= len(train_loader) @@ -176,7 +177,7 @@ def main(): val_psnr = 0 model.eval() for batch in tqdm(val_loader): - mse, psnr = eval_step(batch, model, device) + mse, psnr = eval_step(batch, model, device, criterion) val_loss += mse val_psnr += psnr val_loss /= len(val_loader) -- GitLab