Commit 672cd717 authored by Alexandre Chapin
Make new model slot attention

parent 78417e2e
.visualisation_1639.png (77.6 KiB)

osrt/layers.py

@@ -5,6 +5,7 @@ import numpy as np
 import math
 from einops import rearrange, repeat
+import torch.nn.functional as F
 
 __USE_DEFAULT_INIT__ = False
@@ -194,10 +195,11 @@ class SlotAttention(nn.Module):
     @edit : we changed the code so as to make it possible to handle a different number of slots depending on the input images
     """
     def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):
+                 randomize_initial_slots=False, gain = 1, temperature_factor = 1):
         super().__init__()
         self.num_slots = num_slots
+        self.temperature_factor = temperature_factor
         self.batch_slots = []
         self.iters = iters
         self.scale = slot_dim ** -0.5
@@ -207,24 +209,31 @@ class SlotAttention(nn.Module):
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
 
         self.eps = eps
+        self.slots_mu = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+        self.slots_log_sigma = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
 
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
+        self.to_q = nn.Linear(slot_dim, slot_dim, bias=False)
+        self.to_k = nn.Linear(input_dim, slot_dim, bias=False)
+        self.to_v = nn.Linear(input_dim, slot_dim, bias=False)
+        nn.init.xavier_uniform_(self.to_q.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_k.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_v.weight, gain = gain)
 
         self.gru = nn.GRUCell(slot_dim, slot_dim)
 
         self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
+            nn.Linear(slot_dim, hidden_dim),
             nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
+            nn.Linear(hidden_dim, slot_dim)
         )
 
         self.norm_input = nn.LayerNorm(input_dim)
         self.norm_slots = nn.LayerNorm(slot_dim)
         self.norm_pre_mlp = nn.LayerNorm(slot_dim)
-    def forward(self, inputs, masks=None):
+    def forward(self, inputs, slots = None):
         """
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
@@ -232,74 +241,56 @@ class SlotAttention(nn.Module):
         batch_size, num_inputs, dim = inputs.shape
         inputs = self.norm_input(inputs)
 
-        # Initialize the slots. Shape: [batch_size, num_slots, slot_dim].
-        if self.randomize_initial_slots:
-            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
-            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
-        else:
-            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
-
         k, v = self.to_k(inputs), self.to_v(inputs)
 
+        if slots is None:
+            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_dim, device = self.slots_mu.device)
+
         # Multiple rounds of attention.
         for _ in range(self.iters):
             slots_prev = slots
-            norm_slots = self.norm_slots(slots)
-            q = self.to_q(norm_slots)
+            slots = self.norm_slots(slots)
+            q = self.to_q(slots)
+            q *= self.scale
 
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale # Dot product and normalization
-            if masks != None:
-                temp_masks = masks.unsqueeze(1)
-                attention_masking = torch.where(temp_masks == 1.0, float("-inf"), temp_masks).to(device=dots.device)
-                dots += attention_masking
+            attn_logits = torch.bmm(q, k.transpose(-1, -2))
+            attn_pixelwise = F.softmax(attn_logits / self.temperature_factor, dim = 1)
 
             # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
+            attn_slotwise = F.normalize(attn_pixelwise + self.eps, p = 1, dim = -1)
 
-            # Weighted mean
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_inputs, slot_dim]
+            # shape: [batch_size, num_slots, slot_dim]
+            updates = torch.bmm(attn_slotwise, v)
 
             # Slot update
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))
 
-        return slots # [batch_size, num_slots, dim]
+        return slots, attn_logits, attn_slotwise # [batch_size, num_slots, dim]
 
     def change_slots_number(self, num_slots):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
-### Utils for SlotAttentionAutoEncoder
-def build_grid(resolution):
-    ranges = [np.linspace(0., 1., num=res) for res in resolution]
-    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
-    grid = np.stack(grid, axis=-1)
-    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
-    grid = np.expand_dims(grid, axis=0)
-    grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
-
-class SoftPositionEmbed(nn.Module):
-    """Adds soft positional embedding with learnable projection.
-    Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""
-    def __init__(self, hidden_size, resolution):
-        """Builds the soft position embedding layer.
-        Args:
-            hidden_size: Size of input feature dimension.
-            resolution: Tuple of integers specifying width and height of grid.
-        """
+class PositionEmbeddingImplicit(nn.Module):
+    """
+    Position embedding extracted from
+    https://github.com/vadimkantorov/yet_another_pytorch_slot_attention/blob/master/models.py
+    """
+    def __init__(self, hidden_dim):
         super().__init__()
-        self.dense = JaxLinear(4, hidden_size)
-        self.grid = build_grid(resolution)
+        self.dense = nn.Linear(4, hidden_dim)
 
-    def forward(self, inputs):
-        return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
+    def forward(self, x):
+        spatial_shape = x.shape[-3:-1]
+        grid = torch.stack(torch.meshgrid(*[torch.linspace(0., 1., r, device = x.device) for r in spatial_shape]), dim = -1)
+        grid = torch.cat([grid, 1 - grid], dim = -1)
+        return x + self.dense(grid)
 def fourier_encode(x, max_freq, num_bands = 4):
     x = x.unsqueeze(-1)
@@ -313,7 +304,6 @@ def fourier_encode(x, max_freq, num_bands = 4):
     x = torch.cat((x, orig_x), dim = -1)
     return x
-
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
@@ -393,4 +383,4 @@ class TransformerSlotAttention(nn.Module):
             x_d = self_attn(inputs) + inputs
             inputs = self_ff(x_d) + x_d
 
-        return slots # [batch_size, num_slots, dim]
+        return slots, None, None # [batch_size, num_slots, dim]
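For clarity, a minimal usage sketch of the updated SlotAttention interface follows. It only illustrates the new three-tuple return value and the tensor shapes implied by the diff; the batch and input sizes are arbitrary, and it assumes forward accepts the optional slots argument as reconstructed above (slots are then drawn from slots_mu / slots_log_sigma when none are passed).

import torch
from osrt.layers import SlotAttention

# Defaults follow the constructor in the diff; gain and temperature_factor are the new arguments.
slot_attn = SlotAttention(num_slots=6, input_dim=768, slot_dim=1536, gain=1, temperature_factor=1)

feats = torch.randn(2, 1024, 768)                     # [batch_size, num_inputs, input_dim]
slots, attn_logits, attn_slotwise = slot_attn(feats)

print(slots.shape)          # torch.Size([2, 6, 1536]) -> [batch_size, num_slots, slot_dim]
print(attn_logits.shape)    # torch.Size([2, 6, 1024]) -> raw q.k^T logits
print(attn_slotwise.shape)  # torch.Size([2, 6, 1024]) -> attention renormalized over the inputs for each slot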
osrt/model.py

 from torch import nn
 import torch
+import torch.nn.functional as F
 import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
+from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers
 
 class OSRT(nn.Module):
     def __init__(self, cfg):
         super().__init__()
@@ -75,39 +78,30 @@ class SlotAttentionAutoEncoder(nn.Module):
         self.num_iterations = num_iterations
 
         self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU()
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
         )
 
         self.decoder_initial_size = (8, 8)
         self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1)
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
         )
 
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+        self.decoder_pos = PositionEmbeddingImplicit(64)
 
         self.layer_norm = nn.LayerNorm(64)
         self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
        )
 
         model_type = cfg['model']['model_type']
@@ -128,35 +122,22 @@ class SlotAttentionAutoEncoder(nn.Module):
                 depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        # `image` has shape: [batch_size, num_channels, width, height].
-
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image)  # CNN Backbone.
-        x = self.encoder_pos(x).permute(0, 2, 3, 1)  # Position embedding.
-        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2)
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
+
         x = self.decoder_pos(x)
-        x = self.decoder_cnn(x).permute(0, 2, 3, 1)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode)
 
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
 
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
+        recons, masks = x.split((3, 1), dim = 2)
+        masks = masks.softmax(dim = 1)
+        recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
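To make the new decoder path above easier to follow, here is a short standalone sketch of just the split-and-combine step on dummy tensors. The shapes mirror the code in the diff (3 RGB channels plus 1 alpha logit per slot); it is illustrative only and does not call the repository's classes.

import torch

batch_size, num_slots, height, width = 2, 6, 128, 128

# Dummy decoder output already unflattened to [batch_size, num_slots, 4, H, W].
x = torch.randn(batch_size, num_slots, 4, height, width)

recons, masks = x.split((3, 1), dim = 2)        # [B, S, 3, H, W] and [B, S, 1, H, W]
masks = masks.softmax(dim = 1)                  # normalize the alpha masks across slots
recon_combined = (recons * masks).sum(dim = 1)  # [B, 3, H, W]

print(recon_combined.shape)  # torch.Size([2, 3, 128, 128])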
New file (evaluation / visualisation script for the slot attention autoencoder; the filename is not shown in this view):

import datetime
import time
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import yaml

from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr

from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm import tqdm

def main():
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help="Path to the config file.")
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    torch.manual_seed(args.seed)

    ### Hyperparameters of the model.
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    base_learning_rate = 0.0004
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resolution = (128, 128)

    #### Create datasets
    vis_dataset = data.get_dataset('test', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = SlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:')
    print(f'Model slot attention: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)

    ckpt = {
        'network': model,
        'optimizer': optimizer,
        'global_step': 1639
    }
    #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
    """ckpt = torch.load('~/ckpt.pth')
    model = ckpt['network']"""
    model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])

    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
    image = F.interpolate(image, size=128)
    image = image.to(device)
    recon_combined, recons, masks, slots, attn_slotwise = model(image)

    loss = nn.MSELoss()
    loss_value = loss(recon_combined, image)
    psnr = mse2psnr(loss_value)
    print(f"MSE value : {loss_value} VS PSNR {psnr}")

    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=1639, save_file=True)

if __name__ == "__main__":
    main()
\ No newline at end of file
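For context on the printed metric: mse2psnr from osrt.utils.common presumably converts the MSE of images normalized to [0, 1] into PSNR via the standard formula; a hypothetical stand-in (not the repository's actual implementation) would be:

import torch

def mse2psnr_sketch(mse: torch.Tensor) -> torch.Tensor:
    # PSNR = -10 * log10(MSE), assuming pixel values normalized to [0, 1].
    return -10. * torch.log10(mse)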