diff --git a/.visualisation_90000.png b/.visualisation_90000.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b5f312d0b2bacc14a66db604b930f9bf43aaad2
Binary files /dev/null and b/.visualisation_90000.png differ
diff --git a/osrt/model.py b/osrt/model.py
index 932abd9224e956af29d9eb1d05b426b77461e465..6c8850456d123c3714f816752f1a0df9b7442285 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -121,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise else None
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
diff --git a/train_sa.py b/train_sa.py
index 1cf29f6363dc0648576e786cb665013f136da107..79379b06054071619dc1fddbb3df246f51744b8c 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -81,8 +81,6 @@ def main():
 
     trainer.fit(model, train_loader, val_loader)
 
-
-
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
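
The context lines above build a separate loader for visualization once `trainer.fit` returns. A runnable sketch of that pattern with a stand-in dataset; the real call is `data.get_dataset('train', cfg['data'])`, and the batch size, shuffling, and worker count below are assumptions, since the diff truncates the `DataLoader(` call:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for data.get_dataset('train', cfg['data']); random images are
# enough to show the loader pattern.
vis_dataset = TensorDataset(torch.rand(8, 3, 64, 64))

# batch_size=1 and shuffle=False are assumptions: visualization typically
# walks a few fixed examples rather than shuffled training batches.
vis_loader = DataLoader(vis_dataset, batch_size=1, shuffle=False, num_workers=0)

(batch,) = next(iter(vis_loader))
print(batch.shape)  # torch.Size([1, 3, 64, 64])
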
diff --git a/visualize_sa.py b/visualize_sa.py
index 881d9d474000ee763ed1877310a79ebf2c755119..557a8180e9d6c687c3db93b49d1c9882c26bee8c 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -52,7 +52,7 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
     checkpoint = torch.load(args.ckpt)
-    
+
     model.load_state_dict(checkpoint['state_dict'])
 
     model.eval()
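
The surrounding lines show the checkpoint-loading pattern used for visualization: `torch.load`, then `load_state_dict` on the `'state_dict'` entry that PyTorch Lightning writes, then `eval()`. A self-contained version with a stand-in module; `map_location` is an assumption I would add when a GPU-trained checkpoint may be opened on a CPU-only host, whereas the script above calls `torch.load(args.ckpt)` directly:

import torch
import torch.nn as nn

# Stand-in for LitSlotAttentionAutoEncoder; any nn.Module shows the pattern.
model = nn.Linear(4, 2)

# Save a checkpoint in the same layout PyTorch Lightning produces, i.e. the
# weights live under the 'state_dict' key.
torch.save({'state_dict': model.state_dict()}, '/tmp/demo.ckpt')

# map_location='cpu' is an assumption: it lets GPU-trained checkpoints load
# on hosts without a GPU.
checkpoint = torch.load('/tmp/demo.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # disable training-only behaviour before visualization
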