Commit 29d94b3f authored by Alexandre Chapin

Verify positional encoding

parent f8daecee
@@ -4,7 +4,7 @@ import torch.nn.init as init
 import numpy as np
 import math
-from einops import rearrange
+from einops import rearrange, repeat
 
 __USE_DEFAULT_INIT__ = False
@@ -295,6 +295,19 @@ class SoftPositionEmbed(nn.Module):
     def forward(self, inputs):
         return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
+
+def fourier_encode(x, max_freq, num_bands = 4):
+    x = x.unsqueeze(-1)
+    device, dtype, orig_x = x.device, x.dtype, x
+
+    scales = torch.linspace(1., max_freq / 2, num_bands, device = device, dtype = dtype)
+    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+
+    x = x * scales * math.pi
+    x = torch.cat([x.sin(), x.cos()], dim = -1)
+    x = torch.cat((x, orig_x), dim = -1)
+    return x
+
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
@@ -313,6 +326,8 @@ class TransformerSlotAttention(nn.Module):
         self.depth = depth
         self.self_head = self_heads
         self.cross_heads=cross_heads
+        self.max_freq = 10
+        self.num_freq_bands = 6
 
         ### Cross-attention layers
         self.cs_layers = nn.ModuleList([])
@@ -335,12 +350,15 @@
         self.randomize_initial_slots = randomize_initial_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
 
         self.norm_input = nn.LayerNorm(input_dim)
 
     def forward(self, inputs):
         """
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
         """
-        batch_size, num_inputs, dim = inputs.shape
+        batch_size, *axis = inputs.shape
+        device = inputs.device
 
         inputs = self.norm_input(inputs)
@@ -350,18 +368,23 @@
         else:
             slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
-        ############### TODO : adapt this part of code
-        # data = torch.cat((data, enc_pos.reshape(b,-1,enc_pos.shape[-1])), dim = -1) TODO : add a positional encoding here
+        ############### TODO : check positional encoding
+        axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), (int(np.sqrt(axis[0])), int(np.sqrt(axis[0])))))
+        pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
+        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
+        enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
+        enc_pos = repeat(enc_pos, '... -> b ...', b = batch_size)
+        inputs = torch.cat((inputs, enc_pos.reshape(batch_size, -1, enc_pos.shape[-1])), dim = -1)
 
         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
 
-            x = cross_attn(slots, data) + slots # Cross-attention + Residual
+            x = cross_attn(slots, inputs) + slots # Cross-attention + Residual
             slots = cross_ff(x) + x # Feed-forward + Residual
 
             ## Apply self-attention on input tokens but only before last depth layer
             if i != self.depth - 1:
                 self_attn, self_ff = self.sf_layers[i]
-                x_d = self_attn(data) + data
-                data = self_ff(x_d) + x_d
+                x_d = self_attn(inputs) + inputs
+                inputs = self_ff(x_d) + x_d
 
         return slots # [batch_size, num_slots, dim]
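A note on verifying the encoding end to end: after the concatenation in forward, each token carries input_dim + 2 * (2 * num_freq_bands + 1) = input_dim + 26 channels, so the blocks in self.cs_layers and self.sf_layers have to accept this widened dimension rather than input_dim alone; a mismatch would surface at the first cross-attention call. The snippet below sketches that bookkeeping only, assuming fourier_encode from the earlier hunk is in scope; batch_size = 2, input_dim = 180 and the 16x16 token grid are made-up numbers, and the attention layers themselves are not reproduced.

import torch
from einops import rearrange, repeat

batch_size, input_dim = 2, 180               # made-up sizes for illustration
h = w = 16                                   # num_inputs = h * w = 256 tokens
max_freq, num_freq_bands = 10, 6             # values hard-coded in __init__ above

inputs = torch.randn(batch_size, h * w, input_dim)

axis_pos = [torch.linspace(-1., 1., steps=s) for s in (h, w)]
pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)       # [h, w, 2]
enc_pos = fourier_encode(pos, max_freq, num_freq_bands)    # [h, w, 2, 13]
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')       # [h, w, 26]
enc_pos = repeat(enc_pos, '... -> b ...', b=batch_size)    # [b, h, w, 26]

inputs = torch.cat((inputs, enc_pos.reshape(batch_size, -1, enc_pos.shape[-1])), dim=-1)
print(inputs.shape)   # torch.Size([2, 256, 206]) = [b, num_inputs, input_dim + 26]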