From 96ec6825c664233969efd8d473bfe8229222ecf0 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 15:44:25 +0200
Subject: [PATCH] Fix duplicate padding argument in decoder

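The decoder's first four ConvTranspose2d layers passed the padding keyword
twice (padding=2, padding=1), which Python rejects at parse time with
"SyntaxError: keyword argument repeated", so the module could not be
constructed at all. Keep a single padding=1 for these layers.

Also point visualise.py at the updated checkpoint file (ckpt_1639.pth).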
---
 osrt/model.py | 8 ++++----
 visualise.py  | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/osrt/model.py b/osrt/model.py
index 833dba6..29ac38e 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -73,10 +73,10 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         self.decoder_initial_size = (8, 8)
         self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(64, 4, kernel_size=3)
         )
diff --git a/visualise.py b/visualise.py
index 677f46a..19d311b 100644
--- a/visualise.py
+++ b/visualise.py
@@ -66,7 +66,7 @@ def main():
     #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
     """ckpt = torch.load('~/ckpt.pth')
     model = ckpt['network']"""
-    model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])
+    model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"])
 
     image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
     image = F.interpolate(image, size=128)
-- 
GitLab