diff --git a/.visualisation_90000.png b/.visualisation_90000.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b5f312d0b2bacc14a66db604b930f9bf43aaad2
Binary files /dev/null and b/.visualisation_90000.png differ
diff --git a/osrt/model.py b/osrt/model.py
index 932abd9224e956af29d9eb1d05b426b77461e465..6c8850456d123c3714f816752f1a0df9b7442285 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -121,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise else None
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
diff --git a/train_sa.py b/train_sa.py
index 1cf29f6363dc0648576e786cb665013f136da107..79379b06054071619dc1fddbb3df246f51744b8c 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -81,8 +81,6 @@ def main():
 
     trainer.fit(model, train_loader, val_loader)
 
-
-
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
diff --git a/visualize_sa.py b/visualize_sa.py
index 881d9d474000ee763ed1877310a79ebf2c755119..557a8180e9d6c687c3db93b49d1c9882c26bee8c 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -52,7 +52,7 @@ def main():
 
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
    checkpoint = torch.load(args.ckpt)
-    
+
     model.load_state_dict(checkpoint['state_dict'])
     model.eval()
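
Note on the osrt/model.py change: truth-testing a PyTorch tensor with more than one element raises "RuntimeError: Boolean value of Tensor with more than one element is ambiguous", and a one-element tensor equal to zero evaluates as falsy, so the old guard "if attn_slotwise" only behaved correctly when attn_slotwise was actually None. The sketch below demonstrates both guards; the tensor shape is hypothetical and not taken from the repo.

import torch

# Hypothetical slot-wise attention map; any multi-element tensor triggers the bug.
attn_slotwise = torch.rand(2, 4, 16)

try:
    if attn_slotwise:  # old guard: truth-tests the tensor's contents
        pass
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

if attn_slotwise is not None:  # fixed guard: only checks for presence
    reshaped = attn_slotwise.unsqueeze(-2).unflatten(-1, (4, 4))  # reshape as in the diff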