Commit b79778fe authored by Alexandre Chapin

Fix issue with pos encode slot att

parent 803079e6
...@@ -347,121 +347,6 @@ class TransformerSlotAttention(nn.Module):
        return slots # [batch_size, num_slots, dim]
def unstack_and_split(x, batch_size, num_channels=3):
    """Unstack batch dimension and split into channels and alpha mask."""
    unstacked = x.view(batch_size, -1, *x.shape[1:])
    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
    return channels, masks
def spatial_flatten(x):
    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
def spatial_broadcast(slots, resolution):
    """Broadcast slot features to a 2D grid and collapse slot dimension."""
    # `slots` has shape: [batch_size, num_slots, slot_size].
    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
    grid = slots.repeat(1, resolution[0], resolution[1], 1)
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid
# TODO: adapt this model
class SlotAttentionAutoEncoder(nn.Module):
    """
    Slot Attention as introduced by Locatello et al., with an auto-encoder to extract image embeddings.
    Implementation inspired by the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py
    """
    def __init__(self, resolution, num_slots, num_iterations):
        """Builds the Slot Attention-based auto-encoder.
        Args:
            resolution: Tuple of integers specifying width and height of input image.
            num_slots: Number of slots in Slot Attention.
            num_iterations: Number of iterations in Slot Attention.
        """
        super().__init__()
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.decoder_initial_size = (8, 8)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
        )
        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
        self.layer_norm = nn.LayerNorm(64)
        self.mlp = nn.Sequential(
            JaxLinear(64, 64),
            nn.ReLU(),
            JaxLinear(64, 64)
        )
        self.slot_attention = SlotAttention(
            num_slots=self.num_slots,
            slot_dim=64,
            hidden_dim=128,
            iters=self.num_iterations)
    def forward(self, image):
        # `image` has shape: [batch_size, width, height, num_channels].
        # Convolutional encoder with position embedding.
        x = self.encoder_cnn(image)  # CNN Backbone.
        #x = self.encoder_pos(x)  # Position embedding.
        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
        # `x` has shape: [batch_size, width*height, input_size].
        # Slot Attention module.
        slots = self.slot_attention(x)
        # `slots` has shape: [batch_size, num_slots, slot_size].
        # Spatial broadcast decoder.
        x = spatial_broadcast(slots, self.decoder_initial_size)
        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
        #x = self.decoder_pos(x)
        x = self.decoder_cnn(x)
        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
        # Undo combination of slot and batch dimension; split alpha masks.
        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
        # `masks` has shape: [batch_size, num_slots, width, height, 1].
        # Normalize alpha masks over slots.
        masks = torch.softmax(masks, dim=1)
        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
        # `recon_combined` has shape: [batch_size, width, height, num_channels].
        return recon_combined, recons, masks, slots
def build_grid(resolution):
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
...@@ -469,7 +354,7 @@ def build_grid(resolution):
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    return np.concatenate([grid, 1.0 - grid], axis=-1)
    return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2)  # from [b, h, w, c] to [b, c, h, w]
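The changed return above is the commit's actual fix: `build_grid` now emits the positional grid channels-first, matching the layout of `nn.Conv2d` feature maps. A quick shape check (a minimal sketch; the lines hidden by the hunk header are assumed to stack and reshape the meshgrid as in the official repo):

import numpy as np

ranges = [np.linspace(0., 1., num=res) for res in (8, 8)]
grid = np.stack(np.meshgrid(*ranges, sparse=False, indexing="ij"), axis=-1)  # [8, 8, 2]
grid = np.expand_dims(grid.astype(np.float32), axis=0)                       # [1, 8, 8, 2]
grid = np.concatenate([grid, 1.0 - grid], axis=-1)                           # [1, 8, 8, 4]
print(grid.transpose(0, 3, 1, 2).shape)                                      # (1, 4, 8, 8)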
class SoftPositionEmbed(nn.Module):
    """Adds soft positional embedding with learnable projection.
...@@ -487,6 +372,4 @@ class SoftPositionEmbed(nn.Module):
        self.grid = build_grid(resolution)
    def forward(self, inputs):
        print(inputs.shape)
        print(self.dense(torch.tensor(self.grid).cuda()).shape)
        return inputs + self.dense(torch.tensor(self.grid).cuda())
\ No newline at end of file
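The `forward` kept above rebuilds a tensor from the NumPy grid and hardcodes `.cuda()` on every call. A device-agnostic variant would register the grid as a buffer once and let `.to(device)` carry it along. This is a sketch under assumptions, not the committed code: the 1x1 convolution as the learnable projection is a guess chosen to match the now channels-first grid.

import torch
import torch.nn as nn

class SoftPositionEmbedBuffered(nn.Module):
    """Hypothetical variant: soft positional embedding without hardcoded .cuda()."""
    def __init__(self, hidden_size, resolution):
        super().__init__()
        # Project the 4 coordinate channels up to the feature width; the grid is
        # channels-first ([1, 4, height, width]) after the build_grid fix above.
        self.dense = nn.Conv2d(4, hidden_size, kernel_size=1)
        # A buffer follows the module across .to(device)/.cuda() and is saved in
        # the state dict, but is not a trainable parameter.
        self.register_buffer("grid", torch.from_numpy(build_grid(resolution)))

    def forward(self, inputs):
        # inputs: [batch_size, hidden_size, height, width]
        return inputs + self.dense(self.grid)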
from torch import nn
import torch
import numpy as np
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
import osrt.layers as layers
...@@ -33,3 +36,120 @@ class OSRT(nn.Module):
        raise ValueError(f'Unknown decoder type: {decoder_type}')
def unstack_and_split(x, batch_size, num_channels=3):
    """Unstack batch dimension and split into channels and alpha mask."""
    unstacked = x.view(batch_size, -1, *x.shape[1:])
    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
    return channels, masks
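For reference, a shape check of `unstack_and_split` (a minimal sketch with made-up sizes; it assumes the channels-last layout described in the decoder comments below):

import torch

batch_size, num_slots = 2, 5
x = torch.randn(batch_size * num_slots, 16, 16, 4)  # stacked decoder output: RGB + alpha
channels, masks = unstack_and_split(x, batch_size=batch_size)
print(channels.shape)  # torch.Size([2, 5, 16, 16, 3])
print(masks.shape)     # torch.Size([2, 5, 16, 16, 1])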
def spatial_flatten(x):
    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
def spatial_broadcast(slots, resolution):
    """Broadcast slot features to a 2D grid and collapse slot dimension."""
    # `slots` has shape: [batch_size, num_slots, slot_size].
    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
    grid = slots.repeat(1, resolution[0], resolution[1], 1)
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid
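`spatial_broadcast` is the counterpart on the decoder side: each slot vector is tiled over the initial decoder grid, merging the slot dimension into the batch. A quick check with the same made-up sizes:

import torch

slots = torch.randn(2, 5, 64)        # [batch_size, num_slots, slot_size]
grid = spatial_broadcast(slots, (8, 8))
print(grid.shape)  # torch.Size([10, 8, 8, 64]), i.e. [batch*slots, width, height, slot_size]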
# TODO: adapt this model
class SlotAttentionAutoEncoder(nn.Module):
    """
    Slot Attention as introduced by Locatello et al., with an auto-encoder to extract image embeddings.
    Implementation inspired by the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py
    """
    def __init__(self, resolution, num_slots, num_iterations):
        """Builds the Slot Attention-based auto-encoder.
        Args:
            resolution: Tuple of integers specifying width and height of input image.
            num_slots: Number of slots in Slot Attention.
            num_iterations: Number of iterations in Slot Attention.
        """
        super().__init__()
        self.resolution = resolution
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.decoder_initial_size = (8, 8)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
        )
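        # Note: with transposed-conv arithmetic, out = (in - 1) * stride - 2 * padding + kernel_size,
        # this stack maps 8 -> 15 -> 29 -> 57 -> 113 -> 113 -> 111 rather than 8 -> 128 as in the
        # original TF implementation (SAME padding); matching 128x128 would presumably require
        # output_padding=1 on the stride-2 layers and padding=1 on the final layer.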
        self.encoder_pos = SoftPositionEmbed(64, (32, 32))
        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
        self.layer_norm = nn.LayerNorm(64)
        self.mlp = nn.Sequential(
            JaxLinear(64, 64),
            nn.ReLU(),
            JaxLinear(64, 64)
        )
        self.slot_attention = SlotAttention(
            num_slots=self.num_slots,
            slot_dim=64,
            hidden_dim=128,
            iters=self.num_iterations)
    def forward(self, image):
        # `image` has shape: [batch_size, num_channels, width, height].
        print(f"Shape input {image.shape}")
        # Convolutional encoder with position embedding.
        x = self.encoder_cnn(image)  # CNN Backbone.
        print(f"Shape after encoder {x.shape}")
        x = self.encoder_pos(x)  # Position embedding.
        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
        # `x` has shape: [batch_size, width*height, input_size].
        # Slot Attention module.
        slots = self.slot_attention(x)
        # `slots` has shape: [batch_size, num_slots, slot_size].
        # Spatial broadcast decoder.
        x = spatial_broadcast(slots, self.decoder_initial_size)
        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
        #x = self.decoder_pos(x)
        x = self.decoder_cnn(x)
        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
        # Undo combination of slot and batch dimension; split alpha masks.
        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
        # `masks` has shape: [batch_size, num_slots, width, height, 1].
        # Normalize alpha masks over slots.
        masks = torch.softmax(masks, dim=1)
        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
        # `recon_combined` has shape: [batch_size, width, height, num_channels].
        return recon_combined, recons, masks, slots
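The final step normalizes the alpha masks with a softmax over the slot dimension, so each pixel of `recon_combined` is a convex combination of the per-slot reconstructions. Note that `spatial_flatten` above indexes shapes as if the tensor were channels-last, while the CNN encoder produces channels-first features, so the shape comments only line up after a permute; the sketch below therefore exercises only the recombination step, with made-up shapes:

import torch

recons = torch.randn(2, 5, 16, 16, 3)  # per-slot RGB reconstructions
masks = torch.randn(2, 5, 16, 16, 1)   # per-slot alpha logits
masks = torch.softmax(masks, dim=1)    # normalize over the slot dimension
recon_combined = torch.sum(recons * masks, dim=1)
print(recon_combined.shape)            # torch.Size([2, 16, 16, 3])
# Per pixel, the mask weights sum to one:
assert torch.allclose(masks.sum(dim=1), torch.ones(2, 16, 16, 1))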
...@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.optim as optim
import argparse
import yaml
from osrt.layers import SlotAttentionAutoEncoder
from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from torch.utils.data import DataLoader
...