diff --git a/osrt/layers.py b/osrt/layers.py
index 6662e554a5fb2d151ea60c7cd0cf830266af2b74..24c363f866c6b003ff4c1e297b0e84a2bbd90258 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -327,9 +327,9 @@ class TransformerSlotAttention(nn.Module):
         ### Cross-attention layers
         self.cs_layers = nn.ModuleList([])
         for _ in range(depth):
-            # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
+            # The +26 value is due to the positional encoding added
             self.cs_layers.append(nn.ModuleList([
-                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim, selfatt=False), self.input_dim),
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads = self.cross_heads, dim_head= self.hidden_dim, kv_dim=self.input_dim+26, selfatt=False), self.input_dim + 26),
                 PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
             ]))

@@ -337,8 +337,8 @@ class TransformerSlotAttention(nn.Module):
         self.sf_layers = nn.ModuleList([])
         for _ in range(depth-1):
             self.sf_layers.append(nn.ModuleList([
-                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_head, dim_head = self.hidden_dim)),
-                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+                PreNorm(self.input_dim+26, Attention(self.input_dim+26, heads=self.self_head, dim_head = self.hidden_dim)),
+                PreNorm(self.input_dim+26, FeedForward(self.input_dim+26, self.hidden_dim))
             ]))

         ### Initialize slots
@@ -354,10 +354,6 @@ class TransformerSlotAttention(nn.Module):
         """
         batch_size, *axis = inputs.shape
         device = inputs.device
-        print(f"Shape inputs first : {inputs.shape}")
-
-        inputs = self.norm_input(inputs)
-        print(f"Shape inputs after norm : {inputs.shape}")

         if self.randomize_initial_slots:
             slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
@@ -370,11 +366,10 @@ class TransformerSlotAttention(nn.Module):
         pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
         enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
         enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
-        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size)
+        enc_pos = repeat(enc_pos, '... -> b ... ', b = batch_size) # Size = 26
+        inputs = torch.cat((inputs, enc_pos.reshape(batch_size,-1,enc_pos.shape[-1])), dim = -1)

-        inputs = self.norm_input(inputs)
-        print(f"Shape inputs after encoding : {inputs.shape}")

         for i in range(self.depth):
             cross_attn, cross_ff = self.cs_layers[i]
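
Note on the hard-coded "+26": a minimal sketch of where that width could come from, assuming fourier_encode follows the Perceiver convention (raw coordinate plus sin/cos per frequency band, per axis), a 2-D feature grid, and num_freq_bands = 6. None of these values are confirmed by the diff itself; they are stated here only to make the arithmetic checkable.

```python
# Hypothetical derivation of the positional-encoding width appended to the inputs.
# Assumptions (not taken from the diff): 2 spatial axes, num_freq_bands = 6,
# and Perceiver-style fourier_encode output of (2 * num_bands + 1) per axis.
num_axes = 2                        # e.g. height and width of the feature grid
num_freq_bands = 6                  # assumed value of self.num_freq_bands

per_axis = 2 * num_freq_bands + 1   # sin + cos for each band, plus the raw coordinate
enc_dim = num_axes * per_axis       # flattened by rearrange('... n d -> ... (n d)')

assert enc_dim == 26                # matches the "+26" added to input_dim above
print(enc_dim)
```

If that assumption holds, deriving the value (e.g. `pos_dim = num_axes * (2 * num_freq_bands + 1)`) instead of hard-coding 26 would keep the cross- and self-attention widths consistent if the encoding parameters ever change.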