diff --git a/osrt/model.py b/osrt/model.py
index 3ecc2c26c107c32ab5bb4189565393639566b5b2..833dba682b5e589114ac6d771251816a676cccb0 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -1,6 +1,8 @@
+from typing import Any
 from torch import nn
 import torch
 import torch.nn.functional as F
+import torch.optim as optim
 
 import numpy as np
 
@@ -8,7 +11,9 @@ from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
 from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers
+from osrt.utils.common import mse2psnr
 
+import lightning as pl
 
 
 class OSRT(nn.Module):
@@ -39,25 +44,7 @@ class OSRT(nn.Module):
             raise ValueError(f'Unknown decoder type: {decoder_type}')
 
 
-
-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-class SlotAttentionAutoEncoder(nn.Module):
+class LitSlotAttentionAutoEncoder(pl.LightningModule):
     """
     Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
 
@@ -140,4 +127,65 @@ class SlotAttentionAutoEncoder(nn.Module):
         recon_combined = (recons * masks).sum(dim = 1)
 
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+    
+    def configure_optimizers(self) -> Any:
+        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
+        return optimizer
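+
+    # The removed training loop applied linear LR warmup followed by exponential
+    # decay. A minimal sketch of restoring that schedule here, assuming the cfg
+    # dict passed to the constructor is kept on self:
+    #
+    #   t = self.cfg["training"]
+    #   sched = optim.lr_scheduler.LambdaLR(
+    #       optimizer,
+    #       lambda it: min(1.0, it / t["warmup_it"]) * t["decay_rate"] ** (it / t["decay_it"]))
+    #   return {"optimizer": optimizer,
+    #           "lr_scheduler": {"scheduler": sched, "interval": "step"}}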
+    
+    def one_step(self, image):
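+        """Encode an image batch, run slot attention, and decode per-slot RGB and alpha masks."""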
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        
+        # No initial slots are passed (mirroring forward's slots=None default), so
+        # SlotAttention samples its own initialisation.
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = None)
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        
+        x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode)
+
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
+
+        recons, masks = x.split((3, 1), dim = 2)
+        masks = masks.softmax(dim = 1)
+        recon_combined = (recons * masks).sum(dim = 1)
+
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+
+    def training_step(self, batch, batch_idx):
+        """Perform a single training step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        # one_step returns five values; the slot-wise attention maps are unused here.
+        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
+        # recon_combined is channels-first (B, 3, H, W), matching input_image,
+        # so no permute is needed before the MSE.
+        loss_value = F.mse_loss(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+
+        # Lightning's automatic optimization handles zero_grad/backward/step,
+        # so we only log the loss and return it as a tensor.
+        self.log('train_mse', loss_value, on_epoch=True)
+
+        return loss_value
+    
+    def validation_step(self, batch, batch_idx):
+        """Perform a single eval step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        # As in training_step, recon_combined is channels-first and matches input_image.
+        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
+        loss_value = F.mse_loss(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+        psnr = mse2psnr(loss_value)
+        self.log('val_mse', loss_value)
+        self.log('val_psnr', psnr)
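+
+    # Note: under DDP (num_gpus > 1) it may be worth passing sync_dist=True to
+    # the self.log calls above so val_psnr is averaged across ranks before
+    # ModelCheckpoint compares it.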
 
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index a2600fe9c02fd808716ea8df86411ec68be392ac..1164186692f69964292e15c2bc6bd2ea9ed6a024 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -6,13 +6,10 @@ model:
   model_type: sa
 training:
   num_workers: 2 
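+  # Number of GPUs handed to pl.Trainer(devices=...); values > 1 switch on DDP.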
+  num_gpus: 8
   batch_size: 32 
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
-  print_every: 1
-  validate_every: 1
-  checkpoint_every: 1
-  visualize_every: 2
 
diff --git a/train_sa.py b/train_sa.py
index 654a45ac3f98fff36e46413404bbb0aae5b30477..5b11bed6a9fee3e61118be7328ca65bc231efbed 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -2,52 +2,21 @@ import datetime
 import time
 import torch
 import torch.nn as nn
-import torch.optim as optim
 import argparse
 import yaml
 
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
-from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm
 
-def train_step(batch, model, optimizer, device, criterion):
-    """Perform a single training step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
+import lightning as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
 
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots  # Unused.
-
-    # Get and apply gradients.
-    optimizer.zero_grad()
-    loss_value.backward()
-    optimizer.step()
-
-    return loss_value.item()
-
-def eval_step(batch, model, device, criterion):
-    """Perform a single eval step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
-
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots  # Unused.
-    psnr = mse2psnr(loss_value)
-
-    return loss_value.item(), psnr.item()
 
 def main():
     # Arguments
@@ -64,20 +33,17 @@ def main():
         cfg = yaml.load(f, Loader=yaml.CLoader)
 
     ### Set random seed.
-    torch.manual_seed(args.seed)
+    pl.seed_everything(args.seed, workers=True)
 
     ### Hyperparameters of the model.
     batch_size = cfg["training"]["batch_size"]
+    num_gpus = cfg["training"]["num_gpus"]
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
-    base_learning_rate = 0.0004
     num_train_steps = cfg["training"]["max_it"]
     warmup_steps = cfg["training"]["warmup_it"]
     decay_rate = cfg["training"]["decay_rate"]
     decay_steps = cfg["training"]["decay_it"]
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    criterion = nn.MSELoss()
-
     resolution = (128, 128)
     
     #### Create datasets
@@ -90,71 +56,36 @@ def main():
     val_loader = DataLoader(
         val_dataset, batch_size=batch_size, num_workers=1,
         shuffle=True, worker_init_fn=data.worker_init_fn)
-    
-    vis_dataset = data.get_dataset('test', cfg['data'])
-    vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
-        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
-    num_params = sum(p.numel() for p in model.parameters())
-
-    print('Number of parameters:')
-    print(f'Model slot attention: {num_params}')
-
-    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
-
-    #### Prepare checkpoint manager.
-    global_step = 0
-    ckpt = {
-        'network': model,
-        'optimizer': optimizer,
-        'global_step': global_step
-    }
-    ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
-    # ckpt = torch.load(args.ckpt + '/ckpt.pth')
-    model = ckpt['network']
-    optimizer = ckpt['optimizer']
-    global_step = ckpt['global_step']
-
-    """ TODO : setup wandb
-    if args.wandb:
-        if run_id is None:
-            run_id =  wandb.util.generate_id()
-            print(f'Sampled new wandb run_id {run_id}.')
-        else:
-            print(f'Resuming wandb with existing run_id {run_id}.')
-        # Tell in which mode to launch the logging in W&B (for offline cluster)
-        if args.offline_log:
-            mode = "offline"
-        else:
-            mode = "online"
-        wandb.init(project='osrt', name=os.path.dirname(args.config),
-                   id=run_id, resume=True, mode=mode, sync_tensorboard=True) 
-        wandb.config = cfg"""
-
-    start = time.time()
-    epochs = num_train_steps // len(train_loader)
-    for epoch in range(epochs):
-        total_loss = 0
-        model.train()
-        for batch in tqdm(train_loader):
-            # Learning rate warm-up.
-            if global_step < warmup_steps:
-                learning_rate = base_learning_rate * global_step / warmup_steps
-            else:
-                learning_rate = base_learning_rate
-            learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = learning_rate
-
-            total_loss += train_step(batch, model, optimizer, device, criterion)
-            global_step += 1
-
-        total_loss /= len(train_loader)
-        # We save the checkpoints
-        if not epoch % cfg["training"]["checkpoint_every"]:
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+
+    wandb_logger = WandbLogger()
+
+    checkpoint_callback = ModelCheckpoint(
+        save_top_k=10,
+        monitor="val_psnr",
+        mode="max",
+        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
+        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}",  # Lightning appends .ckpt itself
+    )
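+    # Keeps the ten checkpoints with the highest val_psnr logged in validation_step.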
+
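+    # With DDP, each of the num_gpus processes runs its own DataLoader, so the
+    # effective batch size is batch_size * num_gpus (32 * 8 = 256 here).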
+    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
+                         default_root_dir="./logs", logger=wandb_logger,
+                         strategy="ddp" if num_gpus > 1 else "auto",
+                         callbacks=[checkpoint_callback], deterministic=True,
+                         log_every_n_steps=100, max_steps=num_train_steps)
+
+    trainer.fit(model, train_loader, val_loader)
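+
+    # Lightning moves the model and each batch to the right device, which is why
+    # the manual .to(device) calls from the old loop are no longer needed.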
+
+if __name__ == "__main__":
+    main()
-            # Save the checkpoint of the model.
-            ckpt['global_step'] = global_step
-            ckpt['model_state_dict'] = model.state_dict()
@@ -163,27 +94,5 @@ def main():
 
-        # We visualize some test data
-        if not epoch % cfg["training"]["visualize_every"]:
-            image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
-            image = F.interpolate(image, size=128)
-            image = image.to(device)
-            recon_combined, recons, masks, slots = model(image)
-            visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)
-        # Log the training loss.
-        if not epoch % cfg["training"]["print_every"]:
-            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-        # We visualize some test data
-        if not epoch % cfg["training"]["validate_every"]:
-            val_loss = 0
-            val_psnr = 0
-            model.eval()
-            for batch in tqdm(val_loader):
-                mse, psnr = eval_step(batch, model, device, criterion)
-                val_loss += mse
-                val_psnr += psnr
-            val_loss /= len(val_loader)
-            val_psnr /= len(val_loader)
-            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-            model.train()
-                        
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/visualise.py b/visualise.py
index 05c7d2ea84833a8fefb9855379961b6e9f1b6cb0..677f46ad352f9cf2700ce046b525b7bf740e036a 100644
--- a/visualise.py
+++ b/visualise.py
@@ -6,7 +6,7 @@ import torch.optim as optim
 import argparse
 import yaml
 
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
 from osrt.utils.common import mse2psnr
@@ -15,6 +15,8 @@ from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm
 
+# TODO : setup with lightning
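+# A checkpoint written by ModelCheckpoint in train_sa.py could later be restored
+# with something like (a sketch; the exact filename comes from the training run):
+#   model = LitSlotAttentionAutoEncoder.load_from_checkpoint(
+#       "checkpoints/slot_att-clevr3d-epoch-psnr.ckpt",
+#       resolution=resolution, num_slots=10, num_iterations=num_iterations, cfg=cfg)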
+
 def main():
     # Arguments
     parser = argparse.ArgumentParser(
@@ -48,7 +50,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
     num_params = sum(p.numel() for p in model.parameters())
 
     print('Number of parameters:')