diff --git a/osrt/layers.py b/osrt/layers.py index aadbdf50c6356436ab672f9f9c4f756ab1cbb445..d74ff57f8b4fd7bc7da7ca96666ba60700c26b33 100644 --- a/osrt/layers.py +++ b/osrt/layers.py @@ -355,7 +355,7 @@ class TransformerSlotAttention(nn.Module): batch_size, *axis = inputs.shape device = inputs.device - inputs = self.norm_input(inputs) + #inputs = self.norm_input(inputs) if self.randomize_initial_slots: slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim] @@ -368,12 +368,12 @@ class TransformerSlotAttention(nn.Module): pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1) enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') - enc_pos = repeat(enc_pos, '... -> b ...', b = batch_size) + enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size) inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1) for i in range(self.depth): cross_attn, cross_ff = self.cs_layers[i] - print(f"Shape slots {slots.shape} an inputs shape {inputs.shape}") + print(f"Shape inputs : {inputs.shape}") x = cross_attn(slots, z = inputs) + slots # Cross-attention + Residual slots = cross_ff(x) + x # Feed-forward + Residual diff --git a/outputs/visualisation_8000.png b/outputs/visualisation_8000.png new file mode 100644 index 0000000000000000000000000000000000000000..84ae3fad1a4d8968c3bd89ebb65bbba276509a3b Binary files /dev/null and b/outputs/visualisation_8000.png differ diff --git a/outputs/visualisation_9000.png b/outputs/visualisation_9000.png new file mode 100644 index 0000000000000000000000000000000000000000..55588a785d95b9e18267c5003bd6813c6550f3a6 Binary files /dev/null and b/outputs/visualisation_9000.png differ diff --git a/visualize_sa.py b/visualize_sa.py index 5d4c67f67a093fafe0a9f9fcc2450e10d84fa4a9..a1d4ce15204654a4018d62ecc25e34a6341c7a72 100644 --- a/visualize_sa.py +++ 
b/visualize_sa.py @@ -50,7 +50,7 @@ def main(): shuffle=True, worker_init_fn=data.worker_init_fn) #### Create model - model = LitSlotAttentionAutoEncoder(resolution, 6, num_iterations, cfg=cfg).to(device) + model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device) checkpoint = torch.load(args.ckpt) model.load_state_dict(checkpoint['state_dict'])