diff --git a/osrt/model.py b/osrt/model.py index 4b04c5b0ee3fd8d523554e9bcdec6e1a7460a002..d176d9abd4d115b3c6268625988327e9ada1f67b 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -131,7 +131,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) def configure_optimizers(self) -> Any: - print(self.parameters()) optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08) return optimizer diff --git a/osrt/utils/visualization_utils.py b/osrt/utils/visualization_utils.py index 93f8c4e7013dacc84d2d9ba86d77d518a4e8c30a..6e36fafe5039d8abf7a76931366649be8a6c5599 100644 --- a/osrt/utils/visualization_utils.py +++ b/osrt/utils/visualization_utils.py @@ -102,7 +102,6 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo # Extract data and put it on a plot ax[0].imshow(image) ax[0].set_title('Image') - print(image) ax[1].imshow((recon_combined * 255).astype(np.uint8)) ax[1].set_title('Recon.') for i in range(6): diff --git a/outputs/visualisation_0.png b/outputs/visualisation_0.png new file mode 100644 index 0000000000000000000000000000000000000000..8eb5eed64e59cbda6ac42889cee7c9e41013c5bc Binary files /dev/null and b/outputs/visualisation_0.png differ diff --git a/outputs/visualisation_3000.png b/outputs/visualisation_3000.png new file mode 100644 index 0000000000000000000000000000000000000000..a7484acd5e6e7e899d3446e27cbe7ccb0780734a Binary files /dev/null and b/outputs/visualisation_3000.png differ diff --git a/.visualisation_1639.png b/outputs/visualisation_6000.png similarity index 100% rename from .visualisation_1639.png rename to outputs/visualisation_6000.png diff --git a/visualize_sa.py b/visualize_sa.py index 9153006d27592d4d0961fc60f246f29d60505ef6..5d4c67f67a093fafe0a9f9fcc2450e10d84fa4a9 100644 --- a/visualize_sa.py +++ b/visualize_sa.py @@ -26,6 +26,8 @@ def main(): parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.') parser.add_argument('--seed', type=int, default=0, help='Random seed.') parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path') + parser.add_argument('--output', type=str, default=".", help='Folder in which to save images') + parser.add_argument('--step', type=int, default=".", help='Step of the model') args = parser.parse_args() with open(args.config, 'r') as f: @@ -63,7 +65,8 @@ def main(): loss = nn.MSELoss() loss_value = loss(recon_combined, image) psnr = mse2psnr(loss_value) - visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True) + print(f"MSE {loss_value} and PSNR {psnr}") + visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.output, step=args.step, save_file=True) if __name__ == "__main__": main() \ No newline at end of file