diff --git a/.visualisation_0.png b/.visualisation_0.png
deleted file mode 100644
index 6f0184272777793b0d4181822f23ecf15a2b2a1d..0000000000000000000000000000000000000000
Binary files a/.visualisation_0.png and /dev/null differ
diff --git a/.visualisation_1.png b/.visualisation_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..75d42e018221d28f38312ca284c393c15372b045
Binary files /dev/null and b/.visualisation_1.png differ
diff --git a/Nonevisualisation_0.png b/Nonevisualisation_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..446572f1ca3373425fa03050a45a8e14dfe58b5e
Binary files /dev/null and b/Nonevisualisation_0.png differ
diff --git a/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 b/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0
deleted file mode 100644
index ddc38a00a26221598c3c339d50d8e4418047e23a..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 and /dev/null differ
diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_0/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 b/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0
deleted file mode 100644
index 98f994d92a84a7072f246483250046a883e74f30..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 and /dev/null differ
diff --git a/lightning_logs/version_1/hparams.yaml b/lightning_logs/version_1/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_1/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0
deleted file mode 100644
index 9e9c05c6e308e39ca9c9298dc2b9b8a32a4b516b..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 and /dev/null differ
diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_2/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0
deleted file mode 100644
index a7cc5deb3a254b64b177f15573a663f3ad9d0de6..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 and /dev/null differ
diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_3/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0
deleted file mode 100644
index 65a1f5c0cfc4c7d12dea17678f70732aaaf2731c..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 and /dev/null differ
diff --git a/lightning_logs/version_4/hparams.yaml b/lightning_logs/version_4/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_4/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0
deleted file mode 100644
index 73d3b9d0a715b4f3a4b30267a21646af46b3fabe..0000000000000000000000000000000000000000
Binary files a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 and /dev/null differ
diff --git a/lightning_logs/version_5/hparams.yaml b/lightning_logs/version_5/hparams.yaml
deleted file mode 100644
index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000
--- a/lightning_logs/version_5/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/osrt/model.py b/osrt/model.py
index 6c8850456d123c3714f816752f1a0df9b7442285..24842c9d926bda12ab66cfb52bbe0d99c2fcd0e7 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -100,7 +100,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
     
     def configure_optimizers(self) -> Any:
         optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
@@ -108,7 +108,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
     
     def one_step(self, image):
         x = self.encoder(image)
-        
+        attn_shape = x.shape[-3:-1]
         slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
         x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
         x = self.decoder(x)
@@ -120,8 +120,9 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         recons, masks = x.split((3, 1), dim = 2)
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
-
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
+        
+        
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, attn_shape) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
diff --git a/outputsvisualisation_2.png b/outputsvisualisation_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..e2f2c455e52630ed1a800d7e02e6778ba25724c3
Binary files /dev/null and b/outputsvisualisation_2.png differ
diff --git a/outputsvisualisation_3.png b/outputsvisualisation_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..892a39c54d7466ba2f2d678fd65bd5b7b2bf9795
Binary files /dev/null and b/outputsvisualisation_3.png differ
diff --git a/visualize_sa.py b/visualize_sa.py
index 557a8180e9d6c687c3db93b49d1c9882c26bee8c..2e4031d46c69a38c0a82a1451ac90f6f7b52be18 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -26,7 +26,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
-    parser.add_argument('--output', type=str, default=".", help='Folder in which to save images')
+    parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images')
     parser.add_argument('--step', type=int, default=".", help='Step of the model')
 
     args = parser.parse_args()