Commit 2322dc27 authored by Alexandre Chapin

Fix value error because of positional encoding

parent e6e1600b
@@ -327,9 +327,9 @@ class TransformerSlotAttention(nn.Module):
         ### Cross-attention layers
         self.cs_layers = nn.ModuleList([])
         for _ in range(depth):
-            # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
+            # The +26 value is due to the positional encoding added
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False), self.input_dim),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim+26, selfatt=False), self.input_dim + 26),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))
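For context on this hunk: the removed comment quoted the constructor signature `Attention(dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None)`. Below is a minimal sketch, assuming the usual Perceiver-style pattern where `kv_dim` sets the input width of the key/value projection when `selfatt=False`; it is a hypothetical simplification, not this repository's exact implementation. Since the encoder features now arrive with 26 positional-encoding channels concatenated, a key/value projection built with `kv_dim=self.input_dim` no longer matches their last dimension, hence `kv_dim=self.input_dim+26`.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Minimal sketch of a cross-attention block with a separate key/value
    input width (kv_dim). Hypothetical simplification, not the repo's code."""
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads, self.scale = heads, dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # With selfatt=False, keys/values come from a context tensor whose last
        # dimension must equal kv_dim -- this is where the +26 matters.
        self.to_kv = nn.Linear(dim if selfatt else kv_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, z=None):
        context = x if z is None else z          # z: e.g. encoder features + pos. encoding
        b, n, _ = x.shape
        h = self.heads
        q = self.to_q(x).view(b, n, h, -1).transpose(1, 2)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        k = k.view(b, -1, h, q.shape[-1]).transpose(1, 2)
        v = v.view(b, -1, h, q.shape[-1]).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# Shape check with made-up sizes: slots attend over inputs carrying 26 extra channels.
slot_dim, input_dim = 128, 180
attn = Attention(slot_dim, heads=4, dim_head=32, kv_dim=input_dim + 26, selfatt=False)
slots = torch.randn(2, 6, slot_dim)
feats = torch.randn(2, 64, input_dim + 26)      # encoder features + Fourier encoding
print(attn(slots, z=feats).shape)               # torch.Size([2, 6, 128])
```

The third argument to `PreNorm` presumably sizes the LayerNorm applied to the context tensor, which would be why it is bumped to `self.input_dim + 26` as well (an inference from the diff, not verified against the full source).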
@@ -337,8 +337,8 @@ class TransformerSlotAttention(nn.Module):
         self.sf_layers = nn.ModuleList([])
         for _ in range(depth-1):
             self.sf_layers.append(nn.ModuleList([
-                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_head, dim_head = self.hidden_dim)),
-                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+                PreNorm(self.input_dim+26, Attention(self.input_dim+26, heads=self.self_head, dim_head = self.hidden_dim)),
+                PreNorm(self.input_dim+26, FeedForward(self.input_dim+26, self.hidden_dim))
             ]))

         ### Initialize slots
@@ -354,10 +354,6 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
-        print(f"Shape inputs first : {inputs.shape}")
-        inputs = self.norm_input(inputs)
-        print(f"Shape inputs after norm : {inputs.shape}")

         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -370,11 +366,10 @@ class TransformerSlotAttention(nn.Module):
         pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
         enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
         enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
-        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
+        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size) # Size = 26
         inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1)
-        print(f"Shape inputs after encoding : {inputs.shape}")
+        inputs = self.norm_input(inputs)

         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
...
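Where the hard-coded 26 comes from: a Perceiver-style `fourier_encode` emits, for each coordinate axis, `num_freq_bands` sines, `num_freq_bands` cosines, plus the raw coordinate, i.e. `2*num_freq_bands + 1` channels; after `rearrange(..., '... n d -> ... (n d)')` the axes are flattened, giving `num_axes * (2*num_freq_bands + 1)` extra channels on top of `input_dim`. With two spatial axes and `num_freq_bands = 6` (assumed values, consistent with the 26 but not visible in this diff), that is 2 * 13 = 26. A minimal sketch reproducing the count:

```python
import math
import torch
from einops import rearrange

def fourier_encode(x, max_freq, num_bands=4):
    # Perceiver-style Fourier features: per scalar coordinate, emit
    # num_bands sines + num_bands cosines + the raw coordinate itself.
    x = x.unsqueeze(-1)                                   # (..., 1)
    scales = torch.linspace(1., max_freq / 2, num_bands,
                            device=x.device, dtype=x.dtype)
    xs = x * scales * math.pi                             # (..., num_bands)
    return torch.cat((xs.sin(), xs.cos(), x), dim=-1)     # (..., 2*num_bands + 1)

# Assumed settings that reproduce the hard-coded 26:
num_freq_bands, max_freq = 6, 10.
axis = (8, 8)                                             # two spatial axes
axis_pos = [torch.linspace(-1., 1., steps=n) for n in axis]
pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1)  # (8, 8, 2)
enc_pos = fourier_encode(pos, max_freq, num_freq_bands)              # (8, 8, 2, 13)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')                 # (8, 8, 26)
print(enc_pos.shape[-1])   # 26 = 2 axes * (2*6 + 1) channels per axis
```

A possible follow-up would be to compute this value from `self.num_freq_bands` and the number of input axes instead of hard-coding `+26`, so the layer sizes stay in sync if the encoding settings ever change.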