from torch import nn
import torch
import numpy as np
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
import osrt.layers as layers


class OSRT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        encoder_type = cfg['encoder']
        decoder_type = cfg['decoder']

        layers.__USE_DEFAULT_INIT__ = cfg.get('use_default_init', False)

        if encoder_type == 'srt':
            self.encoder = ImprovedSRTEncoder(**cfg['encoder_kwargs'])
        elif encoder_type == 'osrt':
            self.encoder = OSRTEncoder(**cfg['encoder_kwargs'])
        elif encoder_type == 'sam':
            self.encoder = FeatureMasking(**cfg['encoder_kwargs'])
        else:
            raise ValueError(f'Unknown encoder type: {encoder_type}')

        if decoder_type == 'spatial_broadcast':
            self.decoder = SpatialBroadcastDecoder(**cfg['decoder_kwargs'])
        elif decoder_type == 'srt':
            self.decoder = ImprovedSRTDecoder(**cfg['decoder_kwargs'])
        elif decoder_type == 'slot_mixer':
            self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
        else:
            raise ValueError(f'Unknown decoder type: {decoder_type}')
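
# Illustrative sketch only: a minimal config dict of the shape OSRT.__init__ expects.
# The keys ('encoder', 'decoder', '*_kwargs', 'use_default_init') come from the code
# above; the values and empty kwargs are placeholders, not defaults of the actual
# encoder/decoder classes.
#
# example_cfg = {
#     'encoder': 'osrt',
#     'encoder_kwargs': {},        # forwarded verbatim to OSRTEncoder(**...)
#     'decoder': 'slot_mixer',
#     'decoder_kwargs': {},        # forwarded verbatim to SlotMixerDecoder(**...)
#     'use_default_init': False,   # optional; defaults to False
# }
# model = OSRT(example_cfg)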


def unstack_and_split(x, batch_size, num_channels=3):
    """Unstack batch dimension and split into channels and alpha mask."""
    unstacked = x.view(batch_size, -1, *x.shape[1:])
    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
    return channels, masks


def spatial_flatten(x):
    """Flatten the two spatial dimensions into a single set dimension."""
    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])


def spatial_broadcast(slots, resolution):
    """Broadcast slot features to a 2D grid and collapse slot dimension."""
    # `slots` has shape: [batch_size, num_slots, slot_size].
    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
    grid = slots.repeat(1, resolution[0], resolution[1], 1)
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid
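
# Shape walk-through of the helpers above (sketch; the concrete sizes are illustrative):
#   slots:                             [B, K, D]            e.g. [2, 7, 64]
#   spatial_broadcast(slots, (8, 8)):  [B*K, 8, 8, D]
#   decoder output (channels-last):    [B*K, H, W, num_channels + 1]
#   unstack_and_split(x, B):           channels [B, K, H, W, num_channels], masks [B, K, H, W, 1]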


class SlotAttentionAutoEncoder(nn.Module):
    """Slot Attention auto-encoder as introduced by Locatello et al., with the
    CNN encoder/decoder used to extract and reconstruct image embeddings.

    Implementation inspired by the official repo:
    https://github.com/google-research/google-research/blob/master/slot_attention/model.py
    """
    def __init__(self, resolution, num_slots, num_iterations):
        """Builds the Slot Attention-based auto-encoder.

        Args:
            resolution: Tuple of integers specifying width and height of input image.
            num_slots: Number of slots in Slot Attention.
            num_iterations: Number of iterations in Slot Attention.
        """
        super().__init__()
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations

        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU()
        )

        self.decoder_initial_size = (8, 8)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1)
        )

        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)

        self.layer_norm = nn.LayerNorm(64)
        self.mlp = nn.Sequential(
            JaxLinear(64, 64),
            nn.ReLU(),
            JaxLinear(64, 64)
        )

        self.slot_attention = SlotAttention(
            num_slots=self.num_slots,
            input_dim=64,
            slot_dim=64,
            hidden_dim=128,
            iters=self.num_iterations)
    def forward(self, image):
        # `image` has shape: [batch_size, num_channels, width, height].

        # Convolutional encoder with position embedding.
        x = self.encoder_cnn(image)  # CNN backbone.
        x = self.encoder_pos(x).permute(0, 2, 3, 1)  # Position embedding.
        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as a set).
        x = self.mlp(self.layer_norm(x))  # Feedforward network on the set.
        # `x` has shape: [batch_size, width*height, input_size].

        # Slot Attention module.
        slots = self.slot_attention(x)
        # `slots` has shape: [batch_size, num_slots, slot_size].

        # Spatial broadcast decoder.
        x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2)
        # `x` has shape: [batch_size*num_slots, slot_size, width_init, height_init].
        x = self.decoder_pos(x)
        x = self.decoder_cnn(x).permute(0, 2, 3, 1)
        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

        # Undo combination of slot and batch dimension; split alpha masks.
        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
        # `masks` has shape: [batch_size, num_slots, width, height, 1].

        # Normalize alpha masks over slots.
        masks = torch.softmax(masks, dim=1)
        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
        # `recon_combined` has shape: [batch_size, width, height, num_channels].

        return recon_combined, recons, masks, slots
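

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original training pipeline):
    # push a random batch through the Slot Attention auto-encoder and check output
    # shapes. The resolution, slot count, and iteration count below are assumptions
    # chosen for illustration, not values prescribed by this repo.
    model = SlotAttentionAutoEncoder(resolution=(128, 128), num_slots=7, num_iterations=3)
    image = torch.randn(2, 3, 128, 128)  # [batch_size, num_channels, height, width]
    recon_combined, recons, masks, slots = model(image)
    print(recon_combined.shape)  # expected: [2, 128, 128, 3]
    print(recons.shape)          # expected: [2, 7, 128, 128, 3]
    print(masks.shape)           # expected: [2, 7, 128, 128, 1]
    print(slots.shape)           # expected: [2, 7, 64]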