Commit 735336c1 authored by Alexandre Chapin

Fix issue with model: call self.parameters() when building the Adam optimizer, and drop the stale slots keyword from the slot-attention call in one_step.

parent efe36191
@@ -129,7 +129,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
 
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08)
+        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
         return optimizer
 
     def one_step(self, image):
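
The first hunk fixes the optimizer setup: torch.optim.Adam expects an iterable of parameter tensors, but self.parameters without the call parentheses hands it the bound method itself, which fails as soon as the optimizer tries to iterate over it. A minimal sketch of the failure mode, using a stand-in nn.Linear rather than the repository's LitSlotAttentionAutoEncoder:

    import torch.nn as nn
    import torch.optim as optim

    # Stand-in module; any nn.Module reproduces the bug.
    model = nn.Linear(4, 2)

    # Buggy: passes the bound method itself. The optimizer constructor tries
    # to build a list from it and raises TypeError ('method' object is not
    # iterable).
    #   optimizer = optim.Adam(model.parameters, lr=1e-3, eps=1e-08)

    # Fixed: call the method so Adam receives an iterator of nn.Parameter
    # tensors, exactly as in the committed line.
    optimizer = optim.Adam(model.parameters(), lr=1e-3, eps=1e-08)
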
@@ -137,7 +137,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         x = self.encoder_pos(x)
         x = self.mlp(self.layer_norm(x))
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots)
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
         x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
         x = self.decoder_pos(x)
         x = self.decoder_cnn(x.movedim(-1, 1))
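
The second hunk removes the slots keyword from the slot-attention call: at that point in one_step no slots variable has been produced yet, so the fixed call lets the slot-attention module initialize its own slots. A hedged sketch of that fallback, assuming the common Slot Attention convention of sampling initial slots from learned Gaussian statistics (the class and parameter names here are illustrative, not the repository's exact code):

    from typing import Optional

    import torch
    import torch.nn as nn

    class SlotAttentionSketch(nn.Module):
        """Illustrative skeleton; only the slot-initialization path is shown."""

        def __init__(self, num_slots: int, dim: int):
            super().__init__()
            self.num_slots = num_slots
            # Learned mean and log-std of the initial slot distribution
            # (assumption: the standard Slot Attention initialization scheme).
            self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
            self.slots_log_sigma = nn.Parameter(torch.zeros(1, 1, dim))

        def forward(self, inputs: torch.Tensor,
                    slots: Optional[torch.Tensor] = None):
            # inputs: (batch, num_inputs, dim), spatially flattened features.
            if slots is None:
                # No slots supplied (the fixed call path): sample fresh ones.
                batch = inputs.shape[0]
                noise = torch.randn(batch, self.num_slots, inputs.shape[-1],
                                    device=inputs.device)
                slots = self.slots_mu + self.slots_log_sigma.exp() * noise
            # ... iterative attention updates over inputs would follow here.
            return slots

With an interface like this, calling self.slot_attention(x.flatten(start_dim = 1, end_dim = 2)) with no slots argument is well defined, which is what the fixed line relies on.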