From 96ec6825c664233969efd8d473bfe8229222ecf0 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 15:44:25 +0200 Subject: [PATCH] Fix problem decoder --- osrt/model.py | 8 ++++---- visualise.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index 833dba6..29ac38e 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -73,10 +73,10 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): self.decoder_initial_size = (8, 8) self.decoder_cnn = nn.Sequential( - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 4, kernel_size=3) ) diff --git a/visualise.py b/visualise.py index 677f46a..19d311b 100644 --- a/visualise.py +++ b/visualise.py @@ -66,7 +66,7 @@ def main(): #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth') """ckpt = torch.load('~/ckpt.pth') model = ckpt['network']""" - model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"]) + model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"]) image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1) image = F.interpolate(image, size=128) -- GitLab