diff --git a/.visualisation_0.png b/.visualisation_0.png deleted file mode 100644 index 6f0184272777793b0d4181822f23ecf15a2b2a1d..0000000000000000000000000000000000000000 Binary files a/.visualisation_0.png and /dev/null differ diff --git a/.visualisation_1.png b/.visualisation_1.png new file mode 100644 index 0000000000000000000000000000000000000000..75d42e018221d28f38312ca284c393c15372b045 Binary files /dev/null and b/.visualisation_1.png differ diff --git a/Nonevisualisation_0.png b/Nonevisualisation_0.png new file mode 100644 index 0000000000000000000000000000000000000000..446572f1ca3373425fa03050a45a8e14dfe58b5e Binary files /dev/null and b/Nonevisualisation_0.png differ diff --git a/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 b/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 deleted file mode 100644 index ddc38a00a26221598c3c339d50d8e4418047e23a..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 and /dev/null differ diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_0/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 b/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 deleted file mode 100644 index 98f994d92a84a7072f246483250046a883e74f30..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 and /dev/null differ diff --git a/lightning_logs/version_1/hparams.yaml b/lightning_logs/version_1/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_1/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 deleted file mode 100644 index 9e9c05c6e308e39ca9c9298dc2b9b8a32a4b516b..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 and /dev/null differ diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_2/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 deleted file mode 100644 index a7cc5deb3a254b64b177f15573a663f3ad9d0de6..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 and /dev/null differ diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_3/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 deleted file mode 100644 index 65a1f5c0cfc4c7d12dea17678f70732aaaf2731c..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 and /dev/null differ diff --git a/lightning_logs/version_4/hparams.yaml b/lightning_logs/version_4/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_4/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 deleted file mode 100644 index 73d3b9d0a715b4f3a4b30267a21646af46b3fabe..0000000000000000000000000000000000000000 Binary files a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 and /dev/null differ diff --git a/lightning_logs/version_5/hparams.yaml b/lightning_logs/version_5/hparams.yaml deleted file mode 100644 index 0967ef424bce6791893e9a57bb952f80fd536e93..0000000000000000000000000000000000000000 --- a/lightning_logs/version_5/hparams.yaml +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/osrt/model.py b/osrt/model.py index 6c8850456d123c3714f816752f1a0df9b7442285..24842c9d926bda12ab66cfb52bbe0d99c2fcd0e7 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -100,7 +100,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): masks = masks.softmax(dim = 1) recon_combined = (recons * masks).sum(dim = 1) - return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) + return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None def configure_optimizers(self) -> Any: optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) @@ -108,7 +108,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): def one_step(self, image): x = self.encoder(image) - + attn_shape = x.shape[-3:-1] slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2)) x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1) x = self.decoder(x) @@ -120,8 +120,9 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): recons, masks = x.split((3, 1), dim = 2) masks = masks.softmax(dim = 1) recon_combined = (recons * masks).sum(dim = 1) - - return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None + + + return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, attn_shape) if attn_slotwise is not None else None def training_step(self, batch, batch_idx): """Perform a single training step.""" diff --git a/outputsvisualisation_2.png b/outputsvisualisation_2.png new file mode 100644 index 0000000000000000000000000000000000000000..e2f2c455e52630ed1a800d7e02e6778ba25724c3 Binary files /dev/null and b/outputsvisualisation_2.png differ diff --git a/outputsvisualisation_3.png b/outputsvisualisation_3.png new file mode 100644 index 0000000000000000000000000000000000000000..892a39c54d7466ba2f2d678fd65bd5b7b2bf9795 Binary files /dev/null and b/outputsvisualisation_3.png differ diff --git a/visualize_sa.py b/visualize_sa.py index 557a8180e9d6c687c3db93b49d1c9882c26bee8c..2e4031d46c69a38c0a82a1451ac90f6f7b52be18 100644 --- a/visualize_sa.py +++ b/visualize_sa.py @@ -26,7 +26,7 @@ def main(): parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.') parser.add_argument('--seed', type=int, default=0, help='Random seed.') parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path') - parser.add_argument('--output', type=str, default=".", help='Folder in which to save images') + parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images') parser.add_argument('--step', type=int, default=".", help='Step of the model') args = parser.parse_args()