diff --git a/osrt/model.py b/osrt/model.py
index 833dba682b5e589114ac6d771251816a676cccb0..29ac38ef8d4c26ea05a0df2e3b3530d88ba9d0ff 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -73,10 +73,10 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         self.decoder_initial_size = (8, 8)
         self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(64, 4, kernel_size=3)
         )
diff --git a/visualise.py b/visualise.py
index 677f46ad352f9cf2700ce046b525b7bf740e036a..19d311b6cd7703d72366b3447465518d659c3775 100644
--- a/visualise.py
+++ b/visualise.py
@@ -66,7 +66,7 @@ def main():
     #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
     """ckpt = torch.load('~/ckpt.pth')
     model = ckpt['network']"""
-    model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])
+    model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"])
     image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
     image = F.interpolate(image, size=128)
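
Note on the first hunk: the removed ConvTranspose2d lines passed the `padding` keyword twice (`padding=2, padding=1`), which Python rejects with a SyntaxError ("keyword argument repeated"), so the module could not be constructed at all; the patch keeps `padding=1`. Below is a minimal sketch, assuming only PyTorch and taking the 8x8 `decoder_initial_size` from the surrounding context lines, that isolates the patched decoder and traces the spatial size through each layer:

```python
import torch
import torch.nn as nn

# Patched decoder from the hunk above, isolated from the LightningModule.
# ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size
decoder_cnn = nn.Sequential(
    nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 4, kernel_size=3),
)

x = torch.zeros(1, 64, 8, 8)  # decoder_initial_size = (8, 8)
for layer in decoder_cnn:
    x = layer(x)
    if not isinstance(layer, nn.ReLU):
        print(type(layer).__name__, tuple(x.shape[-2:]))
# Spatial sizes: 8 -> 17 -> 35 -> 71 -> 143 -> 147 -> 149, with 4 output
# channels (RGB + alpha mask, in the usual slot-attention decoder layout).
```

One design note worth double-checking: with `padding=2, output_padding=1` each strided layer would exactly double the resolution (8 -> 16 -> 32 -> 64 -> 128), matching the SAME-padded transposed convolutions of the reference Slot Attention decoder, whereas `padding=1` ends at 149x149 while `visualise.py` interpolates inputs to 128.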