Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • achapin/Segment-Object-Centric
1 result
Show changes
Commits on Source (2)
......@@ -103,8 +103,8 @@ class RayEncoder(nn.Module):
class PreNorm(nn.Module):
def __init__(self, dim, fn, cross_dim=None):
    """Wrap `fn` with layer normalization applied to its inputs.

    Args:
        dim: Feature dimension normalized before calling `fn`.
        fn: The wrapped module (e.g. an attention or feed-forward block).
        cross_dim: Optional feature dimension of a cross-attention
            context; when given, a second LayerNorm is created for it.
    """
    super().__init__()
    self.fn = fn
    # Fixed: the original assigned `self.norm` twice with the identical
    # expression; a single assignment suffices.
    self.norm = nn.LayerNorm(dim)
    # Only build the context norm when a cross-attention context is used.
    self.norm_cross = nn.LayerNorm(cross_dim) if cross_dim is not None else None
def forward(self, x, **kwargs):
......@@ -112,7 +112,7 @@ class PreNorm(nn.Module):
if self.norm_cross is not None:
z = kwargs['z']
normed_context = self.norm_cross(z)
kwargs.update(cross_val = normed_context)
kwargs.update(z = normed_context)
return self.fn(x, **kwargs)
......@@ -373,7 +373,8 @@ class TransformerSlotAttention(nn.Module):
for i in range(self.depth):
cross_attn, cross_ff = self.cs_layers[i]
x = cross_attn(slots, inputs) + slots # Cross-attention + Residual
print(f"Shape slots {slots.shape} an inputs shape {inputs.shape}")
x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual
slots = cross_ff(x) + x # Feed-forward + Residual
## Apply self-attention on input tokens but only before last depth layer
......
......@@ -131,7 +131,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
def configure_optimizers(self) -> Any:
    """Create the optimizer for training (PyTorch-Lightning hook).

    Returns:
        A ``torch.optim.Adam`` optimizer over all model parameters with
        a fixed learning rate of 1e-3 and eps of 1e-08.
    """
    # Removed debug `print(self.parameters())`: it only printed a
    # generator object, not the parameter values.
    optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
    return optimizer
......
......@@ -102,7 +102,6 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo
# Extract data and put it on a plot
ax[0].imshow(image)
ax[0].set_title('Image')
print(image)
ax[1].imshow((recon_combined * 255).astype(np.uint8))
ax[1].set_title('Recon.')
for i in range(6):
......
outputs/visualisation_0.png

82.9 KiB

outputs/visualisation_3000.png

91 KiB

......@@ -26,6 +26,8 @@ def main():
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
parser.add_argument('--output', type=str, default=".", help='Folder in which to save images')
parser.add_argument('--step', type=int, default=0, help='Step of the model')
args = parser.parse_args()
with open(args.config, 'r') as f:
......@@ -63,7 +65,8 @@ def main():
loss = nn.MSELoss()
loss_value = loss(recon_combined, image)
psnr = mse2psnr(loss_value)
visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
print(f"MSE {loss_value} and PSNR {psnr}")
visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.output, step=args.step, save_file=True)
# Script entry point: run `main()` only when executed directly, not on import.
if __name__ == "__main__":
    main()
\ No newline at end of file