diff --git a/osrt/layers.py b/osrt/layers.py
index f681910a4461948a637dbab4eff83b4b590c8da4..4e9f0d1fa345f72d1c50e17bc47417d4e9f062b7 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -347,121 +347,6 @@ class TransformerSlotAttention(nn.Module):
 
         return slots # [batch_size, num_slots, dim]
-
-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-# TODO : adapt this model
-class SlotAttentionAutoEncoder(nn.Module):
-    """
-    Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
-
-    Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-
-    """
-
-    def __init__(self, resolution, num_slots, num_iterations):
-        """Builds the Slot Attention-based auto-encoder.
-
-        Args:
-            resolution: Tuple of integers specifying width and height of input image.
-            num_slots: Number of slots in Slot Attention.
-            num_iterations: Number of iterations in Slot Attention.
-        """
-        super().__init__()
-        self.resolution = resolution
-        self.num_slots = num_slots
-        self.num_iterations = num_iterations
-
-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU()
-        )
-
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
-        )
-
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
-
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
-        )
-
-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
-
-    def forward(self, image):
-        # `image` has shape: [batch_size, width, height, num_channels].
-
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image) # CNN Backbone.
-        #x = self.encoder_pos(x) # Position embedding.
-        x = spatial_flatten(x) # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x)) # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size)
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
-        #x = self.decoder_pos(x)
-        x = self.decoder_cnn(x)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
-
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
-
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1) # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
-
-        return recon_combined, recons, masks, slots
-
 
 def build_grid(resolution):
     ranges = [np.linspace(0., 1., num=res) for res in resolution]
     grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
@@ -469,7 +354,7 @@ def build_grid(resolution):
     grid = np.reshape(grid, [resolution[0], resolution[1], -1])
     grid = np.expand_dims(grid, axis=0)
     grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
+    return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
 
 class SoftPositionEmbed(nn.Module):
     """Adds soft positional embedding with learnable projection.
@@ -487,6 +372,4 @@ class SoftPositionEmbed(nn.Module):
         self.grid = build_grid(resolution)
 
     def forward(self, inputs):
-        print(inputs.shape)
-        print(self.dense(torch.tensor(self.grid).cuda()).shape)
         return inputs + self.dense(torch.tensor(self.grid).cuda())
\ No newline at end of file
diff --git a/osrt/model.py b/osrt/model.py
index 6480fa8fee535bb17709ae274679dad0927338df..45821355ff17b56ebf0769373ce5211de3999cbf 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -1,7 +1,9 @@
 from torch import nn
+import torch
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
 
 import osrt.layers as layers
@@ -33,3 +35,125 @@ class OSRT(nn.Module):
             raise ValueError(f'Unknown decoder type: {decoder_type}')
 
 
+
+def unstack_and_split(x, batch_size, num_channels=3):
+    """Unstack batch dimension and split channels from the alpha mask."""
+    # `x` has shape: [batch_size*num_slots, num_channels+1, height, width].
+    unstacked = x.view(batch_size, -1, *x.shape[1:])
+    channels, masks = torch.split(unstacked, [num_channels, 1], dim=2)
+    return channels, masks
+
+def spatial_flatten(x):
+    """Flatten spatial dimensions: [batch_size, c, h, w] -> [batch_size, h*w, c]."""
+    x = x.permute(0, 2, 3, 1)
+    return x.reshape(-1, x.shape[1] * x.shape[2], x.shape[-1])
+
+def spatial_broadcast(slots, resolution):
+    """Broadcast slot features to a 2D grid and collapse slot dimension."""
+    # `slots` has shape: [batch_size, num_slots, slot_size].
+    slots = slots.view(-1, slots.shape[-1])[:, :, None, None]
+    grid = slots.repeat(1, 1, resolution[0], resolution[1])
+    # `grid` has shape: [batch_size*num_slots, slot_size, height, width].
+    return grid
+
+
+# TODO : adapt this model
+class SlotAttentionAutoEncoder(nn.Module):
+    """
+    Slot Attention as introduced by Locatello et al., with the auto-encoder part used to extract image embeddings.
+
+    Implementation inspired by the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py
+    """
+
+    def __init__(self, resolution, num_slots, num_iterations):
+        """Builds the Slot Attention-based auto-encoder.
+
+        Args:
+            resolution: Tuple of integers specifying width and height of input image.
+            num_slots: Number of slots in Slot Attention.
+            num_iterations: Number of iterations in Slot Attention.
+        """
+        super().__init__()
+        self.resolution = resolution
+        self.num_slots = num_slots
+        self.num_iterations = num_iterations
+
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU()
+        )
+
+        self.decoder_initial_size = (8, 8)
+        # output_padding=1 mirrors the TF "SAME" padding of the official implementation,
+        # so each stride-2 layer exactly doubles the spatial size: 8 -> 16 -> 32 -> 64 -> 128.
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=1)
+        )
+
+        # The position embedding must match the encoder output resolution; the encoder
+        # is stride-1 throughout, so that is simply the input resolution.
+        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
+        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            JaxLinear(64, 64),
+            nn.ReLU(),
+            JaxLinear(64, 64)
+        )
+
+        self.slot_attention = SlotAttention(
+            num_slots=self.num_slots,
+            slot_dim=64,
+            hidden_dim=128,
+            iters=self.num_iterations)
+
+    def forward(self, image):
+        # `image` has shape: [batch_size, num_channels, height, width].
+
+        # Convolutional encoder with position embedding.
+        x = self.encoder_cnn(image)  # CNN Backbone.
+        x = self.encoder_pos(x)  # Position embedding.
+        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
+        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
+        # `x` has shape: [batch_size, height*width, input_size].
+
+        # Slot Attention module.
+        slots = self.slot_attention(x)
+        # `slots` has shape: [batch_size, num_slots, slot_size].
+
+        # Spatial broadcast decoder.
+        x = spatial_broadcast(slots, self.decoder_initial_size)
+        # `x` has shape: [batch_size*num_slots, slot_size, height_init, width_init].
+        #x = self.decoder_pos(x)
+        x = self.decoder_cnn(x)
+        # `x` has shape: [batch_size*num_slots, num_channels+1, height, width].
+
+        # Undo combination of slot and batch dimension; split alpha masks.
+        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
+        # `recons` has shape: [batch_size, num_slots, num_channels, height, width].
+        # `masks` has shape: [batch_size, num_slots, 1, height, width].
+
+        # Normalize alpha masks over slots.
+        masks = torch.softmax(masks, dim=1)
+        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
+        # `recon_combined` has shape: [batch_size, num_channels, height, width].
+
+        return recon_combined, recons, masks, slots
+
diff --git a/train_sa.py b/train_sa.py
index c087f342d00020730875aed50ad2bab29b1a5750..3afbecd956ba117968bda77e3b69fdde9f6e1eab 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 import argparse
 import yaml
-from osrt.layers import SlotAttentionAutoEncoder
+from osrt.model import SlotAttentionAutoEncoder
 
 from osrt import data
 from torch.utils.data import DataLoader
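
For review purposes, a minimal smoke test of the relocated module. The hyperparameters below (resolution, slot count, iterations, batch size) are illustrative rather than values from the training config; (128, 128) is chosen because it matches the decoder's four stride-2 upsamplings from the (8, 8) initial grid. The sketch also assumes SoftPositionEmbed's learned projection consumes the channels-first grid that the new build_grid returns, and a GPU is required because SoftPositionEmbed currently hard-codes .cuda().

```python
import torch
from osrt.model import SlotAttentionAutoEncoder

# Illustrative settings, not from the training config.
model = SlotAttentionAutoEncoder(resolution=(128, 128), num_slots=7, num_iterations=3).cuda()

# Input must live on the GPU because SoftPositionEmbed hard-codes .cuda().
image = torch.randn(2, 3, 128, 128).cuda()  # [batch_size, num_channels, height, width]

recon_combined, recons, masks, slots = model(image)
print(recon_combined.shape)  # expected: [2, 3, 128, 128] - combined reconstruction
print(recons.shape)          # expected: [2, 7, 3, 128, 128] - per-slot reconstructions
print(masks.shape)           # expected: [2, 7, 1, 128, 128] - per-slot alpha masks
print(slots.shape)           # expected: [2, 7, 64] - slot representations
```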