Skip to content
Snippets Groups Projects
Commit 30c62cd9 authored by Alexandre Chapin :race_car:
Browse files

Handle problem with attention_slotwise

parent 6563a107
No related branches found
No related tags found
No related merge requests found
Showing changes with 6 additions and 11 deletions
.visualisation_0.png

34.8 KiB

.visualisation_1.png

40.1 KiB

Nonevisualisation_0.png

41.1 KiB

File deleted
{}
File deleted
{}
File deleted
{}
File deleted
{}
File deleted
{}
File deleted
{}
......@@ -100,7 +100,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
masks = masks.softmax(dim = 1)
recon_combined = (recons * masks).sum(dim = 1)
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
def configure_optimizers(self) -> Any:
optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
......@@ -108,7 +108,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
def one_step(self, image):
x = self.encoder(image)
attn_shape = x.shape[-3:-1]
slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
x = self.decoder(x)
......@@ -120,8 +120,9 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
recons, masks = x.split((3, 1), dim = 2)
masks = masks.softmax(dim = 1)
recon_combined = (recons * masks).sum(dim = 1)
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, attn_shape) if attn_slotwise is not None else None
def training_step(self, batch, batch_idx):
"""Perform a single training step."""
......
outputsvisualisation_2.png

47.4 KiB

outputsvisualisation_3.png

46.1 KiB

......@@ -26,7 +26,7 @@ def main():
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
parser.add_argument('--output', type=str, default=".", help='Folder in which to save images')
parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images')
parser.add_argument('--step', type=int, default=".", help='Step of the model')
args = parser.parse_args()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment