Commit 6563a107 authored by Alexandre Chapin

Resolve issue on output model

parent 42faaee0
.visualisation_90000.png (33 KiB)

@@ -121,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise else None
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None

     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
...
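The functional change in this hunk is the guard on attn_slotwise: truth-testing a tensor calls bool() on it, which raises a RuntimeError whenever the tensor has more than one element, so the attention maps could never be returned. A minimal sketch of the difference, assuming attn_slotwise is a multi-element tensor when present (the variable name comes from the diff; the shape is illustrative):

```python
import torch

attn_slotwise = torch.rand(2, 3)  # stand-in for the attention map; shape is illustrative

try:
    # Old check: `if attn_slotwise` asks for the boolean value of the whole tensor.
    _ = attn_slotwise.unsqueeze(-2) if attn_slotwise else None
except RuntimeError as err:
    # PyTorch refuses: "Boolean value of Tensor with more than one element is ambiguous"
    print(err)

# New check: compare against None explicitly, which works for tensors of any shape.
out = attn_slotwise.unsqueeze(-2) if attn_slotwise is not None else None
print(out.shape)  # torch.Size([2, 1, 3])
```

Comparing against None keeps the optional-return semantics without ever evaluating the tensor's contents.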
@@ -81,8 +81,6 @@ def main():
     trainer.fit(model, train_loader, val_loader)

     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
...
@@ -52,7 +52,7 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
     checkpoint = torch.load(args.ckpt)
     model.load_state_dict(checkpoint['state_dict'])
     model.eval()
...
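For orientation, a hedged sketch of how the evaluation path in this hunk might be exercised end to end. The constructor call, torch.load, load_state_dict, and eval mirror the hunk; the import path, config, checkpoint path, input resolution, and the assumption that the model's forward accepts a (B, 3, H, W) batch and returns the five-tuple from the first hunk are illustrative placeholders, not part of the commit:

```python
import torch

# Everything marked "hypothetical" below is a placeholder, not taken from the repository.
from model import LitSlotAttentionAutoEncoder  # hypothetical module path

cfg = {"data": {}, "model": {}}                 # hypothetical minimal config
resolution, num_slots, num_iterations = (128, 128), 7, 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirrors the hunk: build the Lightning module and restore the trained weights.
model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
state = torch.load("checkpoints/last.ckpt", map_location=device)  # hypothetical path
model.load_state_dict(state["state_dict"])
model.eval()

with torch.no_grad():
    images = torch.rand(1, 3, *resolution, device=device)       # dummy batch, assumed layout
    recon_combined, recons, masks, slots, attn = model(images)   # five-tuple from the first hunk
```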