Skip to content
Snippets Groups Projects
Commit 735336c1 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix issue with model

parent efe36191
No related branches found
No related tags found
No related merge requests found
......@@ -129,7 +129,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
def configure_optimizers(self) -> Any:
optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08)
optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
return optimizer
def one_step(self, image):
......@@ -137,7 +137,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
x = self.encoder_pos(x)
x = self.mlp(self.layer_norm(x))
slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots)
slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
x = self.decoder_pos(x)
x = self.decoder_cnn(x.movedim(-1, 1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment