Skip to content
Snippets Groups Projects
model.py 6.06 KiB
from torch import nn
import torch
import numpy as np

from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed

import osrt.layers as layers

class OSRT(nn.Module):
    """Container that builds an encoder and a decoder selected by name.

    Args:
        cfg: Mapping read with the keys ``'encoder'`` and ``'decoder'`` (type
            names), ``'encoder_kwargs'`` / ``'decoder_kwargs'`` (keyword
            arguments forwarded to the chosen constructor) and, optionally,
            ``'use_default_init'`` (defaults to ``False``).

    Raises:
        ValueError: if the encoder or decoder type name is not one of the
            recognised options below.
    """

    def __init__(self, cfg):
        super().__init__()
        enc_name = cfg['encoder']
        dec_name = cfg['decoder']

        # Module-level switch in osrt.layers, set before any submodule is
        # constructed (presumably consulted by those layers at init time).
        layers.__USE_DEFAULT_INIT__ = cfg.get('use_default_init', False)

        for key, encoder_cls in (
            ('srt', ImprovedSRTEncoder),
            ('osrt', OSRTEncoder),
            ('sam', FeatureMasking),
        ):
            if enc_name == key:
                self.encoder = encoder_cls(**cfg['encoder_kwargs'])
                break
        else:
            raise ValueError(f'Unknown encoder type: {enc_name}')

        for key, decoder_cls in (
            ('spatial_broadcast', SpatialBroadcastDecoder),
            ('srt', ImprovedSRTDecoder),
            ('slot_mixer', SlotMixerDecoder),
        ):
            if dec_name == key:
                self.decoder = decoder_cls(**cfg['decoder_kwargs'])
                break
        else:
            raise ValueError(f'Unknown decoder type: {dec_name}')



def unstack_and_split(x, batch_size, num_channels=3):
    """Restore the batch axis and separate colour channels from the alpha mask.

    Args:
        x: Tensor whose leading dimension is ``batch_size * k`` and whose last
            dimension has exactly ``num_channels + 1`` entries.
        batch_size: Size of the leading axis to split back out of dim 0.
        num_channels: Number of channels that precede the single mask channel.

    Returns:
        ``(channels, masks)`` with shapes ``[batch_size, k, ..., num_channels]``
        and ``[batch_size, k, ..., 1]``.
    """
    regrouped = x.view(batch_size, -1, *x.shape[1:])
    rgb, alpha = regrouped.split([num_channels, 1], dim=-1)
    return rgb, alpha

def spatial_flatten(x):
    """Collapse axes 1 and 2 of a channels-last map into a single set axis.

    A ``[B, H, W, C]`` tensor becomes ``[B, H * W, C]``. Uses ``view``, so the
    input must be view-compatible (it raises rather than copying).
    """
    height, width = x.shape[1], x.shape[2]
    depth = x.shape[-1]
    return x.view(-1, height * width, depth)

def spatial_broadcast(slots, resolution):
    """Tile each slot vector over a 2-D grid, folding slots into the batch axis.

    Args:
        slots: Tensor whose last axis is the slot feature size, typically
            ``[batch_size, num_slots, slot_size]``; all leading axes are
            flattened together.
        resolution: Pair ``(r0, r1)`` giving the grid size.

    Returns:
        Tensor of shape ``[batch_size * num_slots, r0, r1, slot_size]``. It is a
        materialised copy (``repeat``), not a broadcast view.
    """
    slot_size = slots.shape[-1]
    per_slot = slots.view(-1, slot_size).unsqueeze(1).unsqueeze(2)
    rows, cols = resolution[0], resolution[1]
    return per_slot.repeat(1, rows, cols, 1)

class SlotAttentionAutoEncoder(nn.Module):
    """Slot Attention auto-encoder (Locatello et al.).

    A CNN encodes the image into a set of feature vectors and Slot Attention
    groups them into ``num_slots`` slots. A spatial-broadcast CNN decoder then
    renders every slot into a 3-channel image plus an alpha logit map. The
    per-slot images are mixed with a softmax over the slot axis.

    Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
    """

    def __init__(self, resolution, num_slots, num_iterations):
        """Build the encoder, the Slot Attention module and the decoder.

        Args:
            resolution: Pair of ints for the input image size; passed to the
                encoder position embedding.
            num_slots: Number of slots produced by Slot Attention.
            num_iterations: Number of Slot Attention iterations.
        """
        super().__init__()
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations

        # NOTE: submodules are created in a fixed order, and Sequential indices
        # match the layer lists below. This keeps random initialisation,
        # parameters() order and state_dict keys (e.g. encoder_cnn.0.weight)
        # stable.

        # Encoder: four 5x5 convs (3 -> 64, then 64 -> 64), each followed by
        # ReLU. Stride 1 with padding=2 keeps the spatial size.
        encoder_layers = []
        in_channels = 3
        for _ in range(4):
            encoder_layers.append(nn.Conv2d(in_channels, 64, kernel_size=5, padding=2))
            encoder_layers.append(nn.ReLU())
            in_channels = 64
        self.encoder_cnn = nn.Sequential(*encoder_layers)

        # Decoder starts from an 8x8 grid. Each stride-2 transposed conv
        # (k=5, padding=2, output_padding=1) doubles the size, so four of them
        # give 8 -> 128. The two stride-1 layers keep the size and map to 4
        # channels (3 colour + 1 alpha logit).
        self.decoder_initial_size = (8, 8)
        decoder_layers = []
        for _ in range(4):
            decoder_layers.append(nn.ConvTranspose2d(
                64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1))
            decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2))
        decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1))
        self.decoder_cnn = nn.Sequential(*decoder_layers)

        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)

        # Per-token LayerNorm + two-layer MLP applied to the flattened features.
        self.layer_norm = nn.LayerNorm(64)
        self.mlp = nn.Sequential(JaxLinear(64, 64), nn.ReLU(), JaxLinear(64, 64))

        self.slot_attention = SlotAttention(
            num_slots=self.num_slots,
            input_dim=64,
            slot_dim=64,
            hidden_dim=128,
            iters=self.num_iterations,
        )

    def forward(self, image):
        """Encode ``image`` into slots and reconstruct it.

        Args:
            image: Channels-first tensor ``[batch_size, 3, H, W]``.

        Returns:
            Tuple ``(recon_combined, recons, masks, slots)``:
            ``recon_combined`` is ``[batch_size, H', W', 3]``;
            ``recons`` is ``[batch_size, k, H', W', 3]``;
            ``masks`` is ``[batch_size, k, H', W', 1]``, softmax-normalised over k;
            ``slots`` is the raw Slot Attention output.
            ``k`` is inferred from the decoder output (presumably ``num_slots``).
            ``H' x W'`` is the decoder output size (128x128 from the 8x8 grid,
            assuming the position embeddings preserve shape).
        """
        batch_size = image.shape[0]

        # Encoder: CNN features (spatial size preserved), position embedding,
        # then move channels last. Assumes encoder_pos keeps [B, 64, H, W].
        feats = self.encoder_cnn(image)
        feats = self.encoder_pos(feats).permute(0, 2, 3, 1)   # [B, H, W, 64]
        # Treat the feature map as a set of H*W vectors.
        tokens = spatial_flatten(feats)                       # [B, H*W, 64]
        tokens = self.mlp(self.layer_norm(tokens))

        slots = self.slot_attention(tokens)  # presumably [B, num_slots, 64]

        # Decoder: broadcast each slot to the initial grid, channels-first.
        grid = spatial_broadcast(slots, self.decoder_initial_size)
        grid = grid.permute(0, 3, 1, 2)             # [B*num_slots, 64, 8, 8]
        grid = self.decoder_pos(grid)
        decoded = self.decoder_cnn(grid).permute(0, 2, 3, 1)  # [B*num_slots, H', W', 4]

        # Split the slot axis back out of the batch axis; separate alpha logits.
        recons, mask_logits = unstack_and_split(decoded, batch_size=batch_size)

        masks = torch.softmax(mask_logits, dim=1)   # normalise over slots
        recon_combined = (recons * masks).sum(dim=1)

        return recon_combined, recons, masks, slots