from typing import Any
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
import osrt.layers as layers
from osrt.utils.common import mse2psnr
import lightning as pl
class OSRT(nn.Module):
def __init__(self, cfg):
super().__init__()
encoder_type = cfg['encoder']
decoder_type = cfg['decoder']
layers.__USE_DEFAULT_INIT__ = cfg.get('use_default_init', False)
if encoder_type == 'srt':
self.encoder = ImprovedSRTEncoder(**cfg['encoder_kwargs'])
elif encoder_type == 'osrt':
self.encoder = OSRTEncoder(**cfg['encoder_kwargs'])
elif encoder_type == 'sam':
self.encoder = FeatureMasking(**cfg['encoder_kwargs'])
else:
raise ValueError(f'Unknown encoder type: {encoder_type}')
if decoder_type == 'spatial_broadcast':
self.decoder = SpatialBroadcastDecoder(**cfg['decoder_kwargs'])
elif decoder_type == 'srt':
self.decoder = ImprovedSRTDecoder(**cfg['decoder_kwargs'])
elif decoder_type == 'slot_mixer':
self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
else:
raise ValueError(f'Unknown decoder type: {decoder_type}')
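
# Illustrative (hypothetical) config for OSRT, documenting only the keys read above;
# the exact contents of 'encoder_kwargs' / 'decoder_kwargs' depend on the constructors
# in osrt.encoder and osrt.decoder and are not specified here.
#
#     cfg = {
#         'encoder': 'osrt',            # one of: 'srt', 'osrt', 'sam'
#         'decoder': 'slot_mixer',      # one of: 'spatial_broadcast', 'srt', 'slot_mixer'
#         'use_default_init': False,    # optional, defaults to False
#         'encoder_kwargs': {...},      # forwarded to the chosen encoder
#         'decoder_kwargs': {...},      # forwarded to the chosen decoder
#     }
#     model = OSRT(cfg)
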
class LitSlotAttentionAutoEncoder(pl.LightningModule):
"""
Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
"""
    def __init__(self, resolution, num_slots, num_iterations, cfg):
        """Builds the Slot Attention-based auto-encoder.
        Args:
            resolution: Tuple of integers specifying the width and height of the input image.
            num_slots: Number of slots in Slot Attention.
            num_iterations: Number of iterations in Slot Attention.
            cfg: Config dict; cfg['model']['model_type'] selects the slot attention variant ('sa' or 'tsa').
        """
super().__init__()
self.resolution = resolution
self.num_slots = num_slots
self.num_iterations = num_iterations
self.encoder_cnn = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
)
        self.decoder_initial_size = (8, 8)
        # Interpolation mode used to upsample the decoder output back to the input
        # resolution in forward(); 'bilinear' is an assumption, the attribute was missing.
        self.interpolate_mode = 'bilinear'
self.decoder_cnn = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 4, kernel_size=3)
)
self.encoder_pos = PositionEmbeddingImplicit(64)
self.decoder_pos = PositionEmbeddingImplicit(64)
self.layer_norm = nn.LayerNorm(64)
self.mlp = nn.Sequential(
nn.Linear(64, 64),
nn.ReLU(inplace=True),
nn.Linear(64, 64)
)
model_type = cfg['model']['model_type']
if model_type == 'sa':
self.slot_attention = SlotAttention(
num_slots=self.num_slots,
input_dim=64,
slot_dim=64,
hidden_dim=128,
iters=self.num_iterations)
        elif model_type == 'tsa':
            # Keep the same parameter sizes as the iterative variant above.
            self.slot_attention = TransformerSlotAttention(
                num_slots=self.num_slots,
                input_dim=64,
                slot_dim=64,
                hidden_dim=128,
                depth=self.num_iterations)  # the transformer depth plays the role of the iteration count in the original model
        else:
            raise ValueError(f'Unknown model type: {model_type}')
    def forward(self, image):
        x = self.encoder_cnn(image).movedim(1, -1)
        x = self.encoder_pos(x)
        x = self.mlp(self.layer_norm(x))
        # slots=None assumes the slot attention layer initialises its own slots when none are given
        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim=1, end_dim=2), slots=None)
        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
        x = self.decoder_pos(x)
        x = self.decoder_cnn(x.movedim(-1, 1))
        x = F.interpolate(x, image.shape[-2:], mode=self.interpolate_mode)
        x = x.unflatten(0, (len(image), len(x) // len(image)))
        recons, masks = x.split((3, 1), dim=2)
        masks = masks.softmax(dim=1)
        recon_combined = (recons * masks).sum(dim=1)
        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
    def configure_optimizers(self) -> Any:
        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
        return optimizer
    def one_step(self, image):
        # Identical to forward(); kept as a separate entry point for the training/validation steps.
        return self.forward(image)
    def training_step(self, batch, batch_idx):
        """Perform a single training step."""
        input_image = torch.squeeze(batch.get('input_images'), dim=1)
        input_image = F.interpolate(input_image, size=128)
        # Get the prediction of the model and compute the reconstruction loss
        # (both recon_combined and input_image are channel-first here).
        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
        loss_value = F.mse_loss(recon_combined, input_image)
        del recons, masks, slots  # Unused.
        self.log('train_mse', loss_value, on_epoch=True)
        # Lightning handles zero_grad/backward/step under automatic optimization.
        return loss_value
    def validation_step(self, batch, batch_idx):
        """Perform a single evaluation step."""
        input_image = torch.squeeze(batch.get('input_images'), dim=1)
        input_image = F.interpolate(input_image, size=128)
        # Get the prediction of the model and compute the reconstruction loss.
        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
        loss_value = F.mse_loss(recon_combined, input_image)
        del recons, masks, slots  # Unused.
        psnr = mse2psnr(loss_value)
        self.log('val_mse', loss_value)
        self.log('val_psnr', psnr)
        return loss_value
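

# Minimal smoke-test sketch (not part of the original module): builds the Lightning
# auto-encoder with an illustrative config and checks output shapes on a random batch.
# The resolution/num_slots/num_iterations values are assumptions for illustration, and
# this relies on osrt.layers.SlotAttention accepting slots=None as used in forward().
if __name__ == '__main__':
    cfg = {'model': {'model_type': 'sa'}}
    model = LitSlotAttentionAutoEncoder(resolution=(128, 128), num_slots=7, num_iterations=3, cfg=cfg)
    dummy = torch.randn(2, 3, 128, 128)
    with torch.no_grad():
        recon_combined, recons, masks, slots, attn = model(dummy)
    # Expected (under the assumptions above): recon_combined (2, 3, 128, 128),
    # recons (2, 7, 3, 128, 128), masks (2, 7, 1, 128, 128).
    print(recon_combined.shape, recons.shape, masks.shape, slots.shape)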