Commit 349fae70 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:

Implement lightning Slot Attention

parent 672cd717
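For orientation, the port follows the usual PyTorch Lightning recipe: the plain nn.Module becomes a LightningModule that owns its optimizer and step logic, and the hand-written training loop is replaced by a Trainer. A minimal sketch of the pattern with illustrative names (LitAutoEncoder is a toy, not the OSRT code):

import lightning as pl
import torch.nn.functional as F
from torch import nn, optim

class LitAutoEncoder(pl.LightningModule):
    # Toy module for illustration; the commit applies the same pattern to Slot Attention.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 784))

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        loss = F.mse_loss(self.net(x), x)
        self.log("train_mse", loss)
        return loss  # Lightning runs backward() and optimizer.step() itself

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)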
from typing import Any
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
@@ -8,7 +11,9 @@
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
import osrt.layers as layers
from osrt.utils.common import mse2psnr
import lightning as pl
class OSRT(nn.Module):
@@ -39,25 +44,7 @@
        raise ValueError(f'Unknown decoder type: {decoder_type}')
def unstack_and_split(x, batch_size, num_channels=3):
    """Unstack batch dimension and split into channels and alpha mask."""
    unstacked = x.view(batch_size, -1, *x.shape[1:])
    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
    return channels, masks

def spatial_flatten(x):
    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])

def spatial_broadcast(slots, resolution):
    """Broadcast slot features to a 2D grid and collapse slot dimension."""
    # `slots` has shape: [batch_size, num_slots, slot_size].
    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
    grid = slots.repeat(1, resolution[0], resolution[1], 1)
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid
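For intuition, a quick shape walkthrough of these helpers with purely illustrative sizes (batch of 2, 7 slots, 64-dim slots, an 8x8 decoder grid):

# Illustrative only: check the round trip of shapes through the helpers.
batch_size, num_slots, slot_size = 2, 7, 64
resolution = (8, 8)

slots = torch.randn(batch_size, num_slots, slot_size)
grid = spatial_broadcast(slots, resolution)
# grid: [batch_size * num_slots, 8, 8, 64] -> slots collapsed into the batch dim.

decoded = torch.randn(batch_size * num_slots, *resolution, 4)  # e.g. RGBA decoder output
channels, masks = unstack_and_split(decoded, batch_size)
# channels: [2, 7, 8, 8, 3], masks: [2, 7, 8, 8, 1]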
class LitSlotAttentionAutoEncoder(pl.LightningModule):
    """
    Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
@@ -140,4 +127,65 @@
        recon_combined = (recons * masks).sum(dim=1)
        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
    def configure_optimizers(self) -> Any:
        # Note: parameters() must be called; passing the bound method raises a TypeError.
        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
        return optimizer
    def one_step(self, image):
        x = self.encoder_cnn(image).movedim(1, -1)
        x = self.encoder_pos(x)
        x = self.mlp(self.layer_norm(x))
        # Slot attention over the flattened spatial grid (slots are initialized
        # inside the module; the original line passed an undefined `slots` variable).
        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim=1, end_dim=2))
        # Spatial broadcast: one decoding pass per slot.
        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
        x = self.decoder_pos(x)
        x = self.decoder_cnn(x.movedim(-1, 1))
        x = F.interpolate(x, image.shape[-2:], mode=self.interpolate_mode)
        x = x.unflatten(0, (len(image), len(x) // len(image)))
        # Split per-slot RGB reconstructions from alpha masks, normalized across slots.
        recons, masks = x.split((3, 1), dim=2)
        masks = masks.softmax(dim=1)
        recon_combined = (recons * masks).sum(dim=1)
        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
    def training_step(self, batch, batch_idx):
        """Perform a single training step."""
        input_image = torch.squeeze(batch.get('input_images'), dim=1)
        input_image = F.interpolate(input_image, size=128)

        # Get the prediction of the model and compute the loss
        # (one_step returns five values; the attention maps are unused here).
        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
        input_image = input_image.permute(0, 2, 3, 1)
        # MSE reconstruction loss (nn.MSELoss in the original script).
        loss_value = F.mse_loss(recon_combined, input_image)
        del recons, masks, slots  # Unused.

        # Lightning's automatic optimization handles zero_grad/backward/step,
        # so the manual optimizer calls are dropped and the loss tensor is returned.
        self.log('train_mse', loss_value, on_epoch=True)
        return loss_value
    def validation_step(self, batch, batch_idx):
        """Perform a single eval step."""
        input_image = torch.squeeze(batch.get('input_images'), dim=1)
        input_image = F.interpolate(input_image, size=128)

        # Get the prediction of the model and compute the loss.
        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
        input_image = input_image.permute(0, 2, 3, 1)
        loss_value = F.mse_loss(recon_combined, input_image)
        del recons, masks, slots  # Unused.

        psnr = mse2psnr(loss_value)
        self.log('val_mse', loss_value)
        self.log('val_psnr', psnr)
        return loss_value
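mse2psnr is imported from osrt.utils.common; its definition is not shown in this diff, but for images normalized to [0, 1] the standard conversion is PSNR = -10 * log10(MSE). A minimal sketch under that assumption:

import torch

def mse2psnr(mse: torch.Tensor) -> torch.Tensor:
    # Assumed definition: PSNR in dB for images scaled to [0, 1].
    return -10.0 * torch.log10(mse)

Logging val_psnr here is also what the ModelCheckpoint(monitor="val_psnr", mode="max") callback in the training script keys on when ranking checkpoints.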
@@ -6,13 +6,10 @@
model:
  model_type: sa
training:
  num_workers: 2
  num_gpus: 8
  batch_size: 32
  max_it: 333000000
  warmup_it: 10000
  decay_rate: 0.5
  decay_it: 100000
  print_every: 1
  validate_every: 1
  checkpoint_every: 1
  visualize_every: 2
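For reference, with decay_rate: 0.5 and decay_it: 100000 the schedule in the training script below halves the learning rate every 100k steps, i.e. lr(step) = base_lr * 0.5^(step / 100000), applied on top of a linear warm-up over the first warmup_it steps; max_it caps the total number of optimizer steps.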
@@ -2,52 +2,21 @@
import datetime
import time
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import yaml

from osrt.model import LitSlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr

from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
def train_step(batch, model, optimizer, device, criterion):
    """Perform a single training step."""
    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
    input_image = F.interpolate(input_image, size=128)

    # Get the prediction of the model and compute the loss.
    preds = model(input_image)
    recon_combined, recons, masks, slots = preds
    input_image = input_image.permute(0, 2, 3, 1)
    loss_value = criterion(recon_combined, input_image)
    del recons, masks, slots  # Unused.

    # Get and apply gradients.
    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

    return loss_value.item()

def eval_step(batch, model, device, criterion):
    """Perform a single eval step."""
    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
    input_image = F.interpolate(input_image, size=128)

    # Get the prediction of the model and compute the loss.
    preds = model(input_image)
    recon_combined, recons, masks, slots = preds
    input_image = input_image.permute(0, 2, 3, 1)
    loss_value = criterion(recon_combined, input_image)
    del recons, masks, slots  # Unused.
    psnr = mse2psnr(loss_value)
    return loss_value.item(), psnr.item()
def main():
    # Arguments
@@ -64,20 +33,17 @@
    cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    pl.seed_everything(42, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    base_learning_rate = 0.0004
    num_train_steps = cfg["training"]["max_it"]
    warmup_steps = cfg["training"]["warmup_it"]
    decay_rate = cfg["training"]["decay_rate"]
    decay_steps = cfg["training"]["decay_it"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.MSELoss()
    resolution = (128, 128)

    #### Create datasets
@@ -90,71 +56,36 @@
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=1,
        shuffle=True, worker_init_fn=data.worker_init_fn)
    vis_dataset = data.get_dataset('test', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
    #### Create model
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)

    num_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:')
    print(f'Model slot attention: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)

    #### Prepare checkpoint manager (manual state, superseded by Lightning below).
    global_step = 0
    ckpt = {
        'network': model,
        'optimizer': optimizer,
        'global_step': global_step
    }
    # torch.save returns None, so its result is not assigned.
    torch.save(ckpt, args.ckpt + '/ckpt.pth')
    # ckpt = torch.load(args.ckpt + '/ckpt.pth')
    model = ckpt['network']
    optimizer = ckpt['optimizer']
    global_step = ckpt['global_step']
""" TODO : setup wandb
if args.wandb:
if run_id is None: #print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
run_id = wandb.util.generate_id()
print(f'Sampled new wandb run_id {run_id}.') """
else:
print(f'Resuming wandb with existing run_id {run_id}.') if not epoch % cfg["training"]["checkpoint_every"]:
# Tell in which mode to launch the logging in W&B (for offline cluster)
if args.offline_log:
mode = "offline"
else:
mode = "online"
wandb.init(project='osrt', name=os.path.dirname(args.config),
id=run_id, resume=True, mode=mode, sync_tensorboard=True)
wandb.config = cfg"""
start = time.time()
epochs = num_train_steps // len(train_loader)
for epoch in range(epochs):
total_loss = 0
model.train()
for batch in tqdm(train_loader):
# Learning rate warm-up.
if global_step < warmup_steps:
learning_rate = base_learning_rate * global_step / warmup_steps
else:
learning_rate = base_learning_rate
learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate
total_loss += train_step(batch, model, optimizer, device, criterion)
global_step += 1
total_loss /= len(train_loader)
# We save the checkpoints
if not epoch % cfg["training"]["checkpoint_every"]:
# Save the checkpoint of the model. # Save the checkpoint of the model.
ckpt['global_step'] = global_step ckpt['global_step'] = global_step
ckpt['model_state_dict'] = model.state_dict() ckpt['model_state_dict'] = model.state_dict()
@@ -163,27 +94,5 @@
        # We visualize some test data
        if not epoch % cfg["training"]["visualize_every"]:
            image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
            image = F.interpolate(image, size=128)
            recon_combined, recons, masks, slots = model(image)
            visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)

        # Log the training loss.
        if not epoch % cfg["training"]["print_every"]:
            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")

        # We run validation on the held-out set.
        if not epoch % cfg["training"]["validate_every"]:
            val_loss = 0
            val_psnr = 0
            model.eval()
            for batch in tqdm(val_loader):
                mse, psnr = eval_step(batch, model, device, criterion)
                val_loss += mse
                val_psnr += psnr
            val_loss /= len(val_loader)
            val_psnr /= len(val_loader)
            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
            model.train()
    # Lightning setup: logger, checkpointing, and trainer replace the manual loop above.
    wandb_logger = WandbLogger()

    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="val_psnr",
        mode="max",
        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}",
    )

    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
                         default_root_dir="./logs", logger=wandb_logger,
                         strategy="ddp" if num_gpus > 1 else "auto",
                         callbacks=[checkpoint_callback], deterministic=True,
                         log_every_n_steps=100, max_steps=num_train_steps)

    trainer.fit(model, train_loader, val_loader)

if __name__ == "__main__":
    main()
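The Lightning port drops the manual warm-up/decay schedule above. If that schedule is still wanted, it could live in configure_optimizers instead; a sketch, assuming base_lr, warmup_steps, decay_rate and decay_steps were stashed on the module in __init__ (hypothetical attribute names):

from torch.optim.lr_scheduler import LambdaLR

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.base_lr, eps=1e-08)

    def lr_lambda(step):
        warmup = min(1.0, step / self.warmup_steps)           # linear warm-up
        decay = self.decay_rate ** (step / self.decay_steps)  # exponential decay
        return warmup * decay

    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    # interval="step" makes Lightning step the scheduler every batch, matching the old loop.
    return {"optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}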
@@ -6,7 +6,7 @@
import torch.optim as optim
import argparse
import yaml

from osrt.model import LitSlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr
@@ -15,6 +15,8 @@
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

# TODO : setup with lightning

def main():
    # Arguments
    parser = argparse.ArgumentParser(
@@ -48,7 +50,7 @@
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
    num_params = sum(p.numel() for p in model.parameters())

    print('Number of parameters:')
...