diff --git a/osrt/model.py b/osrt/model.py
index 3ecc2c26c107c32ab5bb4189565393639566b5b2..833dba682b5e589114ac6d771251816a676cccb0 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -1,6 +1,9 @@
+from typing import Any
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn
 import torch
 import torch.nn.functional as F
+import torch.optim as optim
 
 import numpy as np
 
@@ -8,7 +11,9 @@
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
 from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers
+from osrt.utils.common import mse2psnr
+import lightning as pl
 
 
 class OSRT(nn.Module):
@@ -39,25 +44,7 @@
         raise ValueError(f'Unknown decoder type: {decoder_type}')
 
-
-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-class SlotAttentionAutoEncoder(nn.Module):
+class LitSlotAttentionAutoEncoder(pl.LightningModule):
     """
     Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
     """
@@ -140,4 +127,65 @@
         recon_combined = (recons * masks).sum(dim = 1)
 
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+
+    def configure_optimizers(self) -> Any:
+        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
+        return optimizer
+
+    def one_step(self, image):
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+
+        x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode)
+
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
+
+        recons, masks = x.split((3, 1), dim = 2)
+        masks = masks.softmax(dim = 1)
+        recon_combined = (recons * masks).sum(dim = 1)
+
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+
+    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        """Perform a single training step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
+        input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = F.mse_loss(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+
+        # Gradients are applied by Lightning's automatic optimization.
+        self.log('train_mse', loss_value, on_epoch=True)
+
+        return loss_value
+
+    def validation_step(self, batch, batch_idx):
+        """Perform a single eval step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
+        input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = F.mse_loss(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+        psnr = mse2psnr(loss_value)
+        self.log('val_mse', loss_value)
+        self.log('val_psnr', psnr)
+
+        return loss_value
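The Lightning port pins the learning rate at 1e-3 in `configure_optimizers`, while the old loop applied linear warmup followed by exponential decay (and config.yaml still carries `warmup_it`, `decay_rate` and `decay_it`). If that schedule is still wanted, here is a minimal sketch, with the old base LR of 4e-4 and the config constants hard-coded purely for illustration:

```python
import torch.optim as optim

def configure_optimizers(self):
    # Base LR from the old training script; warmup/decay constants from config.yaml.
    optimizer = optim.Adam(self.parameters(), lr=4e-4, eps=1e-08)

    def lr_lambda(step):
        warmup = min(1.0, step / 10000)  # warmup_it
        decay = 0.5 ** (step / 100000)   # decay_rate ** (step / decay_it)
        return warmup * decay

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # "interval": "step" tells Lightning to step the scheduler every batch.
    return {"optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}
```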
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index a2600fe9c02fd808716ea8df86411ec68be392ac..1164186692f69964292e15c2bc6bd2ea9ed6a024 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -6,13 +6,10 @@ model:
   model_type: sa
 training:
   num_workers: 2
+  num_gpus: 8
   batch_size: 32
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
-  print_every: 1
-  validate_every: 1
-  checkpoint_every: 1
-  visualize_every: 2
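One consequence of `num_gpus: 8` worth keeping in mind: under DDP every process builds its own DataLoader with the configured `batch_size`, so the effective batch per optimizer step is their product. An illustrative check:

```python
import yaml

# Illustrative only: path and keys follow the config shown above.
with open("runs/clevr3d/slot_att/config.yaml") as f:
    cfg = yaml.safe_load(f)

per_gpu = cfg["training"]["batch_size"]  # 32
gpus = cfg["training"]["num_gpus"]       # 8
print(f"effective batch size per step: {per_gpu * gpus}")  # 256
```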
diff --git a/train_sa.py b/train_sa.py
index 654a45ac3f98fff36e46413404bbb0aae5b30477..5b11bed6a9fee3e61118be7328ca65bc231efbed 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -2,52 +2,21 @@
 import datetime
 import time
 import torch
 import torch.nn as nn
-import torch.optim as optim
 import argparse
 import yaml
 
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
-from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm
 
+import lightning as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
 
-def train_step(batch, model, optimizer, device, criterion):
-    """Perform a single training step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
-
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots # Unused.
-
-    # Get and apply gradients.
-    optimizer.zero_grad()
-    loss_value.backward()
-    optimizer.step()
-
-    return loss_value.item()
-
-def eval_step(batch, model, device, criterion):
-    """Perform a single eval step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
-
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots # Unused.
-    psnr = mse2psnr(loss_value)
-
-    return loss_value.item(), psnr.item()
 
 def main():
     # Arguments
@@ -64,20 +33,17 @@
         cfg = yaml.load(f, Loader=yaml.CLoader)
 
     ### Set random seed.
-    torch.manual_seed(args.seed)
+    pl.seed_everything(args.seed, workers=True)
 
     ### Hyperparameters of the model.
     batch_size = cfg["training"]["batch_size"]
+    num_gpus = cfg["training"]["num_gpus"]
     num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
-    base_learning_rate = 0.0004
     num_train_steps = cfg["training"]["max_it"]
     warmup_steps = cfg["training"]["warmup_it"]
     decay_rate = cfg["training"]["decay_rate"]
     decay_steps = cfg["training"]["decay_it"]
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    criterion = nn.MSELoss()
 
     resolution = (128, 128)
 
     #### Create datasets
@@ -90,71 +56,36 @@
     val_loader = DataLoader(
         val_dataset, batch_size=batch_size, num_workers=1,
         shuffle=True, worker_init_fn=data.worker_init_fn)
-
-    vis_dataset = data.get_dataset('test', cfg['data'])
-    vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
-        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
-    num_params = sum(p.numel() for p in model.parameters())
-
-    print('Number of parameters:')
-    print(f'Model slot attention: {num_params}')
-
-    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
-
-    #### Prepare checkpoint manager.
-    global_step = 0
-    ckpt = {
-        'network': model,
-        'optimizer': optimizer,
-        'global_step': global_step
-    }
-    ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
-    # ckpt = torch.load(args.ckpt + '/ckpt.pth')
-    model = ckpt['network']
-    optimizer = ckpt['optimizer']
-    global_step = ckpt['global_step']
-
-    """ TODO : setup wandb
-    if args.wandb:
-        if run_id is None:
-            run_id = wandb.util.generate_id()
-            print(f'Sampled new wandb run_id {run_id}.')
-        else:
-            print(f'Resuming wandb with existing run_id {run_id}.')
-        # Tell in which mode to launch the logging in W&B (for offline cluster)
-        if args.offline_log:
-            mode = "offline"
-        else:
-            mode = "online"
-        wandb.init(project='osrt', name=os.path.dirname(args.config),
-                   id=run_id, resume=True, mode=mode, sync_tensorboard=True)
-        wandb.config = cfg"""
-
-    start = time.time()
-    epochs = num_train_steps // len(train_loader)
-    for epoch in range(epochs):
-        total_loss = 0
-        model.train()
-        for batch in tqdm(train_loader):
-            # Learning rate warm-up.
-            if global_step < warmup_steps:
-                learning_rate = base_learning_rate * global_step / warmup_steps
-            else:
-                learning_rate = base_learning_rate
-            learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = learning_rate
-
-            total_loss += train_step(batch, model, optimizer, device, criterion)
-            global_step += 1
-
-        total_loss /= len(train_loader)
-        # We save the checkpoints
-        if not epoch % cfg["training"]["checkpoint_every"]:
-            # Save the checkpoint of the model.
-            ckpt['global_step'] = global_step
-            ckpt['model_state_dict'] = model.state_dict()
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+
+    wandb_logger = WandbLogger()
+
+    checkpoint_callback = ModelCheckpoint(
+        save_top_k=10,
+        monitor="val_psnr",
+        mode="max",
+        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
+        # ModelCheckpoint appends the ".ckpt" extension itself.
+        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}",
+    )
+
+    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
+                         default_root_dir="./logs", logger=wandb_logger,
+                         strategy="ddp" if num_gpus > 1 else "auto",
+                         callbacks=[checkpoint_callback], deterministic=True,
+                         log_every_n_steps=100, max_steps=num_train_steps)
+
+    trainer.fit(model, train_loader, val_loader)
@@ -163,27 +94,5 @@
-
-        # We visualize some test data
-        if not epoch % cfg["training"]["visualize_every"]:
-            image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
-            image = F.interpolate(image, size=128)
-            image = image.to(device)
-            recon_combined, recons, masks, slots = model(image)
-            visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)
-        # Log the training loss.
-        if not epoch % cfg["training"]["print_every"]:
-            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-        # We visualize some test data
-        if not epoch % cfg["training"]["validate_every"]:
-            val_loss = 0
-            val_psnr = 0
-            model.eval()
-            for batch in tqdm(val_loader):
-                mse, psnr = eval_step(batch, model, device, criterion)
-                val_loss += mse
-                val_psnr += psnr
-            val_loss /= len(val_loader)
-            val_psnr /= len(val_loader)
-            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-            model.train()
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
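The manual checkpoint manager is gone: `ModelCheckpoint` keeps the top-10 checkpoints by `val_psnr`, and resuming goes through `ckpt_path` instead of restoring state by hand. A sketch of the resume call, reusing the trainer, model and loaders defined in `main()` above; the checkpoint path is illustrative:

```python
# Resume an interrupted run from a file that ModelCheckpoint wrote earlier.
trainer.fit(model, train_loader, val_loader,
            ckpt_path="checkpoints/epoch=07.ckpt")  # illustrative path
```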
diff --git a/visualise.py b/visualise.py
index 05c7d2ea84833a8fefb9855379961b6e9f1b6cb0..677f46ad352f9cf2700ce046b525b7bf740e036a 100644
--- a/visualise.py
+++ b/visualise.py
@@ -6,7 +6,7 @@
 import torch.optim as optim
 import argparse
 import yaml
 
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
 from osrt.utils.common import mse2psnr
@@ -15,6 +15,8 @@
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm
 
+# TODO : setup with lightning
+
 def main():
     # Arguments
     parser = argparse.ArgumentParser(
@@ -48,7 +50,7 @@
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
     num_params = sum(p.numel() for p in model.parameters())
 
     print('Number of parameters:')
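For the `# TODO : setup with lightning` in visualise.py, loading would go through Lightning's checkpoint API rather than the old `torch.save` dict. A sketch, assuming the constructor keywords match the positional call above (`resolution`, `num_slots`, `num_iterations`); the checkpoint path is illustrative:

```python
import torch
import yaml

from osrt.model import LitSlotAttentionAutoEncoder

with open("runs/clevr3d/slot_att/config.yaml") as f:
    cfg = yaml.safe_load(f)

# Constructor kwargs are passed through explicitly, since the module does not
# call self.save_hyperparameters(); the keyword names here are assumptions.
model = LitSlotAttentionAutoEncoder.load_from_checkpoint(
    "checkpoints/epoch=07.ckpt",  # illustrative path
    map_location="cpu",
    resolution=(128, 128),
    num_slots=cfg["model"]["num_slots"],
    num_iterations=cfg["model"]["iters"],
    cfg=cfg)
model.eval()

image = torch.rand(1, 3, 128, 128)  # stand-in for a real test batch
with torch.no_grad():
    recon_combined, recons, masks, slots, _ = model.one_step(image)
```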