From 1915d2fe8d95b6fb641b45957f44f73449193b50 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:26:48 +0200
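Subject: [PATCH] Fix loss computation in train_step and eval_step

train_step() called nn.MSELoss(recon_combined, input_image), which passes
the tensors to the loss module's constructor instead of evaluating the
loss. Create the criterion once in main() with nn.MSELoss() and pass it to
both train_step() and eval_step(), so the two steps compute the
reconstruction loss the same way.
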
---
 train_sa.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/train_sa.py b/train_sa.py
index 09e0c1e..14c0972 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm
 
-def train_step(batch, model, optimizer, device):
+def train_step(batch, model, optimizer, device, criterion):
     """Perform a single training step."""
     input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
     input_image = F.interpolate(input_image, size=128)
@@ -24,7 +24,7 @@ def train_step(batch, model, optimizer, device):
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
     input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = nn.MSELoss(recon_combined, input_image)
+    loss_value = criterion(recon_combined, input_image)
     del recons, masks, slots  # Unused.
 
     # Get and apply gradients.
@@ -34,7 +34,7 @@ def train_step(batch, model, optimizer, device):
 
     return loss_value.item()
 
-def eval_step(batch, model, device):
+def eval_step(batch, model, device, criterion):
     """Perform a single eval step."""
     input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
     input_image = F.interpolate(input_image, size=128)
@@ -43,7 +43,7 @@ def eval_step(batch, model, device):
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
     input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = F.mse_loss(recon_combined, input_image)
+    loss_value = criterion(recon_combined, input_image)
     del recons, masks, slots  # Unused.
     psnr = mse2psnr(loss_value)
 
@@ -76,6 +76,7 @@ def main():
     decay_rate = cfg["training"]["decay_rate"]
     decay_steps = cfg["training"]["decay_it"]
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    criterion = nn.MSELoss()
 
     resolution = (128, 128)
     
@@ -148,7 +149,7 @@ def main():
             for param_group in optimizer.param_groups:
                 param_group['lr'] = learning_rate
 
-            total_loss += train_step(batch, model, optimizer, device)
+            total_loss += train_step(batch, model, optimizer, device, criterion)
             global_step += 1
 
         total_loss /= len(train_loader)
@@ -176,7 +177,7 @@ def main():
             val_psnr = 0
             model.eval()
             for batch in tqdm(val_loader):
-                mse, psnr = eval_step(batch, model, device)
+                mse, psnr = eval_step(batch, model, device, criterion)
                 val_loss += mse
                 val_psnr += psnr
             val_loss /= len(val_loader)
-- 
GitLab