Commit e2d5300d authored by Alexandre Chapin

Test new pos encoding

parent 78bc3df2
@@ -355,7 +355,7 @@ class TransformerSlotAttention(nn.Module):
         batch_size, *axis = inputs.shape
         device = inputs.device
-        inputs = self.norm_input(inputs)
+        #inputs = self.norm_input(inputs)
         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)  # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
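For context, the commented-out input norm and the slot initialization above sit at the start of the forward pass. A minimal sketch of how randomized slot initialization is commonly done in slot-attention models, assuming the learned initial_slots have shape [num_slots, slot_dim] and that randomization adds Gaussian noise around the learned means (the repository's exact scheme may differ):

import torch

def init_slots(initial_slots, batch_size, randomize=True, noise_scale=1.0):
    # initial_slots: learned parameter of shape [num_slots, slot_dim]
    # expand the learned means to [batch_size, num_slots, slot_dim]
    slot_means = initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
    if randomize:
        # sample slots around the learned means (hypothetical noise model)
        return slot_means + noise_scale * torch.randn_like(slot_means)
    return slot_means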
@@ -368,12 +368,12 @@ class TransformerSlotAttention(nn.Module):
         pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
         enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
         enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
-        enc_pos = repeat(enc_pos, '... -> b ...', b = batch_size)
+        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size, -1, enc_pos.shape[-1])), dim = -1)
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
+            print(f"Shape slots {slots.shape} and inputs shape {inputs.shape}")
+            print(f"Shape inputs : {inputs}")
             x = cross_attn(slots, z = inputs) + slots  # Cross-attention + residual
             slots = cross_ff(x) + x  # Feed-forward + residual
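The hunk above builds a Fourier positional encoding over the input grid (fourier_encode, then rearrange and repeat over the batch) before the cross-attention loop. A minimal, self-contained sketch of a Perceiver-style fourier_encode and of how the encoding could be assembled, assuming positions normalized to [-1, 1] and hypothetical sizes; the repository's own fourier_encode may differ in its frequency schedule:

import torch

def fourier_encode(x, max_freq, num_bands):
    # x: positions in [-1, 1], shape [..., d]
    # returns features of shape [..., d, 2 * num_bands + 1]
    x = x.unsqueeze(-1)
    freqs = torch.linspace(1.0, max_freq / 2, num_bands, device=x.device, dtype=x.dtype)
    scaled = x * freqs * torch.pi
    # sine/cosine bands concatenated with the raw position
    return torch.cat([scaled.sin(), scaled.cos(), x], dim=-1)

# Hypothetical feature-map and batch sizes, mirroring the meshgrid /
# rearrange / repeat calls in the hunk above:
H, W, batch_size = 8, 8, 2
axis_pos = [torch.linspace(-1.0, 1.0, steps=s) for s in (H, W)]
pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1)  # [H, W, 2]
enc_pos = fourier_encode(pos, max_freq=10.0, num_bands=6)            # [H, W, 2, 13]
enc_pos = enc_pos.reshape(H, W, -1)                                  # '(n d)' flatten
enc_pos = enc_pos.unsqueeze(0).expand(batch_size, -1, -1, -1)        # repeat over batch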
outputs/visualisation_8000.png (47.9 KiB)
outputs/visualisation_9000.png (76.8 KiB)
@@ -50,7 +50,7 @@ def main():
                              shuffle=True, worker_init_fn=data.worker_init_fn)
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 6, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
     checkpoint = torch.load(args.ckpt)
     model.load_state_dict(checkpoint['state_dict'])
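Note on the change above: the second positional argument (which appears to be the slot count) goes from 6 to 10, while the same checkpoint is still loaded afterwards. If any checkpointed parameters depend on the slot count, a strict load would fail; a hedged sketch, reusing args and model from the surrounding script and assuming the standard 'state_dict' key, for surfacing such mismatches before loading:

import torch

state_dict = torch.load(args.ckpt)['state_dict']
model_dict = model.state_dict()
# report parameters whose shapes no longer match the re-sized model
mismatched = {k: (tuple(v.shape), tuple(model_dict[k].shape))
              for k, v in state_dict.items()
              if k in model_dict and v.shape != model_dict[k].shape}
if mismatched:
    print(f"Parameters with incompatible shapes: {mismatched}")
model.load_state_dict(state_dict, strict=False)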