Skip to content
Snippets Groups Projects
Commit 96ec6825 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix problem decoder

parent 1a36bb26
No related branches found
No related tags found
No related merge requests found
......@@ -73,10 +73,10 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
self.decoder_initial_size = (8, 8)
self.decoder_cnn = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 4, kernel_size=3)
)
......
......@@ -66,7 +66,7 @@ def main():
#ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
"""ckpt = torch.load('~/ckpt.pth')
model = ckpt['network']"""
model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])
model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"])
image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
image = F.interpolate(image, size=128)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment