import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np

import math
from einops import rearrange


__USE_DEFAULT_INIT__ = False


class JaxLinear(nn.Linear):
    """ Linear layers with initialization matching the Jax defaults """
    def reset_parameters(self):
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
        else:
            input_size = self.weight.shape[-1]
            std = math.sqrt(1/input_size)
            init.trunc_normal_(self.weight, std=std, a=-2.*std, b=2.*std)
            if self.bias is not None:
                init.zeros_(self.bias)


class ViTLinear(nn.Linear):
    """ Initialization for linear layers used by ViT """
    def reset_parameters(self):
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
        else:
            init.xavier_uniform_(self.weight)
            if self.bias is not None:
                init.normal_(self.bias, std=1e-6)


class SRTLinear(nn.Linear):
    """ Initialization for linear layers used in the SRT decoder """
    def reset_parameters(self):
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
        else:
            init.xavier_uniform_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)
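
# Usage sketch (illustrative): these classes are drop-in replacements for nn.Linear that
# only override the weight/bias initialization; setting __USE_DEFAULT_INIT__ = True falls
# back to the PyTorch defaults. The sizes below are assumed for illustration.
#
#   layer = ViTLinear(256, 512)            # Xavier-uniform weights, near-zero bias
#   y = layer(torch.rand(2, 256))          # -> [2, 512]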


class PositionalEncoding(nn.Module):
    def __init__(self, num_octaves=8, start_octave=0):
        super().__init__()
        self.num_octaves = num_octaves
        self.start_octave = start_octave

    def forward(self, coords, rays=None):
        batch_size, num_points, dim = coords.shape

        octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves)
        octaves = octaves.float().to(coords)
        multipliers = 2**octaves * math.pi
        coords = coords.unsqueeze(-1)
        while len(multipliers.shape) < len(coords.shape):
            multipliers = multipliers.unsqueeze(0)

        scaled_coords = coords * multipliers

        sines = torch.sin(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves)
        cosines = torch.cos(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves)

        result = torch.cat((sines, cosines), -1)
        return result
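
# Usage sketch (illustrative, with assumed shapes): for coords of shape
# [batch_size, num_points, dim], the encoding concatenates sines and cosines over
# `num_octaves` frequencies, giving [batch_size, num_points, dim * num_octaves * 2].
#
#   pe = PositionalEncoding(num_octaves=8, start_octave=0)
#   coords = torch.rand(2, 100, 3)        # e.g. 3D points
#   enc = pe(coords)                      # -> [2, 100, 3 * 8 * 2] = [2, 100, 48]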


class RayEncoder(nn.Module):
    def __init__(self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0):
        super().__init__()
        self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves, start_octave=pos_start_octave)
        self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves, start_octave=ray_start_octave)

    def forward(self, pos, rays):
        if len(rays.shape) == 4:
            batchsize, height, width, _ = rays.shape
            pos_enc = self.pos_encoding(pos.unsqueeze(1))
            pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1)
            pos_enc = pos_enc.repeat(1, 1, height, width)

            rays = rays.flatten(1, 2)
            ray_enc = self.ray_encoding(rays)
            ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1])
            ray_enc = ray_enc.permute((0, 3, 1, 2))
            x = torch.cat((pos_enc, ray_enc), 1)
        else:
            pos_enc = self.pos_encoding(pos)
            ray_enc = self.ray_encoding(rays)
            x = torch.cat((pos_enc, ray_enc), -1)

        return x
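
# Usage sketch (illustrative, with assumed shapes): for per-pixel rays, the encoder
# broadcasts the encoded camera position over the image and concatenates it with the
# encoded ray directions along the channel dimension.
#
#   enc = RayEncoder(pos_octaves=8, ray_octaves=4)
#   pos = torch.rand(2, 3)                # camera positions
#   rays = torch.rand(2, 64, 64, 3)       # per-pixel ray directions
#   x = enc(pos, rays)                    # -> [2, 3*8*2 + 3*4*2, 64, 64] = [2, 72, 64, 64]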

# Transformer implementation based on ViT
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            ViTLinear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            ViTLinear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        if selfatt:
            self.to_qkv = JaxLinear(dim, inner_dim * 3, bias=False)
        else:
            self.to_q = JaxLinear(dim, inner_dim, bias=False)
            self.to_kv = JaxLinear(kv_dim, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(
            JaxLinear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, z=None):
        if z is None:
            qkv = self.to_qkv(x).chunk(3, dim=-1)
        else:
            q = self.to_q(x)
            k, v = self.to_kv(z).chunk(2, dim=-1)
            qkv = (q, k, v)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
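
# Usage sketch (illustrative, with assumed shapes). With selfatt=True the module attends
# over x itself; with selfatt=False it computes queries from x and keys/values from a
# second tensor z with feature size kv_dim (cross-attention).
#
#   self_attn = Attention(dim=256, heads=8, dim_head=64)
#   x = torch.rand(2, 100, 256)
#   y = self_attn(x)                                        # -> [2, 100, 256]
#
#   cross_attn = Attention(dim=256, heads=8, dim_head=64, selfatt=False, kv_dim=128)
#   z = torch.rand(2, 50, 128)
#   y = cross_attn(x, z=z)                                  # -> [2, 100, 256]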


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head,
                                       dropout=dropout, selfatt=selfatt, kv_dim=kv_dim)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x, z=None):
        for attn, ff in self.layers:
            x = attn(x, z=z) + x
            x = ff(x) + x
        return x
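
# Usage sketch (illustrative, with assumed shapes): a stack of pre-norm attention and
# feed-forward blocks with residual connections, as in ViT. Passing z enables
# cross-attention when the blocks were built with selfatt=False.
#
#   encoder = Transformer(dim=256, depth=4, heads=8, dim_head=64, mlp_dim=1024)
#   x = torch.rand(2, 100, 256)
#   y = encoder(x)                        # -> [2, 100, 256]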


class SlotAttention(nn.Module):
    """
    Slot Attention as introduced by Locatello et al.

    @edit: we changed the code to allow a different number of slots depending on the input images
    """
    def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
                 randomize_initial_slots=False):
        super().__init__()

        self.num_slots = num_slots
        self.batch_slots = []
        self.iters = iters
        self.scale = slot_dim ** -0.5
        self.slot_dim = slot_dim

        self.randomize_initial_slots = randomize_initial_slots
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        # Stdev for the randomized slot initialization (assumed default: 1/sqrt(slot_dim))
        self.embedding_stdev = 1.0 / math.sqrt(slot_dim)

        self.eps = eps

        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)

        self.gru = nn.GRUCell(slot_dim, slot_dim)

        self.mlp = nn.Sequential(
            JaxLinear(slot_dim, hidden_dim),
            nn.ReLU(inplace=True),
            JaxLinear(hidden_dim, slot_dim)
        )

        self.norm_input   = nn.LayerNorm(input_dim)
        self.norm_slots   = nn.LayerNorm(slot_dim)
        self.norm_pre_mlp = nn.LayerNorm(slot_dim)

    def forward(self, inputs, masks=None):
        """
        Args:
            inputs: set-latent representation [batch_size, num_inputs, dim]
            masks: optional mask of shape [batch_size, num_inputs]; positions equal to 1.0 are excluded from the attention
        """
        batch_size, num_inputs, dim = inputs.shape

        inputs = self.norm_input(inputs)
        
        # Initialize the slots. Shape: [batch_size, num_slots, slot_dim].
        if self.randomize_initial_slots:
            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
        else:
            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)

        k, v = self.to_k(inputs), self.to_v(inputs)

        # Multiple rounds of attention.
        for _ in range(self.iters):
            slots_prev = slots
            norm_slots = self.norm_slots(slots)

            q = self.to_q(norm_slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale  # scaled dot product

            if masks is not None:
                # Positions where the mask equals 1.0 get -inf logits and are excluded from the attention
                temp_masks = masks.unsqueeze(1).to(dots.device)
                attention_masking = torch.where(temp_masks == 1.0, torch.full_like(temp_masks, float("-inf")), temp_masks)
                dots = dots + attention_masking
            # shape: [batch_size, num_slots, num_inputs]
            attn = dots.softmax(dim=1) + self.eps

            # Weighted mean
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', v, attn)  # shape: [batch_size, num_slots, slot_dim]
            
            # Slot update
            slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
            slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
            slots = slots + self.mlp(self.norm_pre_mlp(slots))

        return slots  # [batch_size, num_slots, slot_dim]

    def change_slots_number(self, num_slots):
        self.num_slots = num_slots
        # Re-create the initial slots on the same device and dtype as the previous parameter
        self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim,
                                                      device=self.initial_slots.device,
                                                      dtype=self.initial_slots.dtype))
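
# Usage sketch (illustrative, with assumed shapes): the module maps a set-latent
# representation to `num_slots` slot vectors via iterative attention, where the softmax
# normalizes over the slots so that slots compete for the input tokens.
#
#   slot_attn = SlotAttention(num_slots=6, input_dim=768, slot_dim=1536)
#   feats = torch.rand(2, 1024, 768)      # set-latent representation
#   slots = slot_attn(feats)              # -> [2, 6, 1536]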

### Utils for SlotAttentionAutoEncoder
def build_grid(resolution):
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    return np.concatenate([grid, 1.0 - grid], axis=-1)
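
# Shape sketch (illustrative): build_grid returns a float32 array holding normalized
# coordinates and their complements along the last axis.
#
#   grid = build_grid((64, 64))           # -> numpy array of shape (1, 64, 64, 4)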

class SoftPositionEmbed(nn.Module):
    """Adds soft positional embedding with learnable projection.
    Implementation extracted from the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""

    def __init__(self, hidden_size, resolution):
        """Builds the soft position embedding layer.

        Args:
            hidden_size: Size of input feature dimension.
            resolution: Tuple of integers specifying width and height of grid.
        """
        super().__init__()
        self.dense = JaxLinear(4, hidden_size)
        # Register the grid as a buffer so it moves to the right device with the module
        self.register_buffer("grid", torch.from_numpy(build_grid(resolution)))

    def forward(self, inputs):
        # Project the grid and add it to the inputs, permuting from [b, h, w, c] to [b, c, h, w]
        return inputs + self.dense(self.grid).permute(0, 3, 1, 2)
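
# Usage sketch (illustrative, with assumed shapes): adds a learned projection of the
# positional grid to a CNN feature map in channels-first layout.
#
#   pos_embed = SoftPositionEmbed(hidden_size=64, resolution=(32, 32))
#   feats = torch.rand(2, 64, 32, 32)     # [b, c, h, w]
#   out = pos_embed(feats)                # -> [2, 64, 32, 32]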
    
### New transformer implementation of SlotAttention inspired by https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
class TransformerSlotAttention(nn.Module):
    """
    An extension of Slot Attention using self-attention, inspired by "Visual Concepts Tokenization" (Yang et al., 2022).
    """
    def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
                 randomize_initial_slots=False):
        super().__init__()

        self.num_slots = num_slots
        self.input_dim = input_dim
        self.batch_slots = []
        self.scale = slot_dim ** -0.5
        self.slot_dim = slot_dim # latent_dim
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.self_head = self_heads
        self.cross_heads = cross_heads

        ### Cross-attention layers
        self.cs_layers = nn.ModuleList([])
        for _ in range(depth):
            self.cs_layers.append(nn.ModuleList([
                PreNorm(self.slot_dim, Attention(self.slot_dim, heads=self.cross_heads, dim_head=self.hidden_dim,
                                                 kv_dim=self.input_dim, selfatt=False)),
                PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
            ]))

        ### Self-attention layers
        self.sf_layers = nn.ModuleList([])
        for _ in range(depth - 1):
            self.sf_layers.append(nn.ModuleList([
                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_head, dim_head=self.hidden_dim)),
                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
            ]))

        ### Initialize slots
        self.randomize_initial_slots = randomize_initial_slots
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        # Stdev for the randomized slot initialization (assumed default: 1/sqrt(slot_dim))
        self.embedding_stdev = 1.0 / math.sqrt(slot_dim)

        ### Input normalization
        self.norm_input = nn.LayerNorm(input_dim)

    def forward(self, inputs):
        """
        Args:
            inputs: set-latent representation [batch_size, num_inputs, dim]
        """
        batch_size, num_inputs, dim = inputs.shape

        data = self.norm_input(inputs)

        if self.randomize_initial_slots:
            # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
        else:
            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)

        # TODO: add a positional encoding to the input tokens here, e.g.
        # data = torch.cat((data, enc_pos.reshape(batch_size, -1, enc_pos.shape[-1])), dim=-1)

        for i in range(self.depth):
            cross_attn, cross_ff = self.cs_layers[i]
            x = cross_attn(slots, z=data) + slots  # cross-attention + residual
            slots = cross_ff(x) + x  # feed-forward + residual

            # Refine the input tokens with self-attention, except after the last layer
            if i != self.depth - 1:
                self_attn, self_ff = self.sf_layers[i]
                x_d = self_attn(data) + data
                data = self_ff(x_d) + x_d

        return slots  # [batch_size, num_slots, slot_dim]
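
# Usage sketch (illustrative, with assumed shapes): slots repeatedly cross-attend to the
# input tokens while the tokens themselves are refined with self-attention, following the
# Visual Concepts Tokenization scheme.
#
#   tsa = TransformerSlotAttention(num_slots=10, depth=6, input_dim=768, slot_dim=1536)
#   feats = torch.rand(2, 1024, 768)
#   slots = tsa(feats)                    # -> [2, 10, 1536]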