diff --git a/compile_video.py b/compile_video.py
index 8b6de1aa9846bd1266a1c82f3e5cecd5ffd4dfaf..f6b5de30cb7f55ed44404e3747e6a001736fae3b 100644
--- a/compile_video.py
+++ b/compile_video.py
@@ -8,7 +8,7 @@ from tqdm import tqdm
 import argparse, os, subprocess
 from os.path import join
 
-from osrt.utils.visualize import setup_axis, background_image
+from osrt.utils.visualization_utils import setup_axis, background_image
 
 def compile_video_plot(path, frames=False, num_frames=1000000000):
 
diff --git a/eval_sa.py b/eval_sa.py
deleted file mode 100644
index 858462dc755c04c79e1a287b98734f28e606d21d..0000000000000000000000000000000000000000
--- a/eval_sa.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from osrt import data
-from osrt.model import SlotAttentionAutoEncoder
-import torch
-import matplotlib.pyplot as plt
-from PIL import Image as Image
-import argparse
-import yaml
-from torch.utils.data import DataLoader
-import torch.nn.functional as F
-from osrt.utils.visualize import visualize_slot_attention
-
-if __name__ == "__main__":
-    # Arguments
-    parser = argparse.ArgumentParser(
-        description='Train a 3D scene representation model.'
-    )
-    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
-    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
-    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
-    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
-
-    args = parser.parse_args()
-    with open(args.config, 'r') as f:
-        cfg = yaml.load(f, Loader=yaml.CLoader)
-
-    # Hyperparameters.
-    seed = 0
-    batch_size = 1
-    num_slots = 7
-    num_iterations = 3
-    resolution = (128, 128)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations)
-    model = torch.load('./ckpt.pth')['network']
-    print(model)
-    model.eval()
-
-
-
-    eval_dataset = data.get_dataset('train', cfg['data'])
-    eval_loader = DataLoader(
-        eval_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
-
-    model = model.to(device)
-
-    image = torch.squeeze(next(iter(eval_loader)).get('input_images').to(device), dim=1)
-    image = F.interpolate(image, size=128)
-    image = image.to(device)
-    recon_combined, recons, masks, slots = model(image)
-
-    visualize_slot_attention(num_slots, image, recon_combined, recons, masks)
-
-    
-
diff --git a/evaluate_sa.py b/evaluate_sa.py
new file mode 100644
index 0000000000000000000000000000000000000000..af75b4bb89159700ede32361e378238a24693ac0
--- /dev/null
+++ b/evaluate_sa.py
@@ -0,0 +1,59 @@
+import argparse
+import yaml
+
+from osrt.model import LitSlotAttentionAutoEncoder
+from osrt import data
+
+from torch.utils.data import DataLoader
+
+import lightning as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+import torch
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Evaluate a Slot Attention autoencoder.'
+    )
+    parser.add_argument('config', type=str, help="Path to the config file.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    ### Set random seed.
+    pl.seed_everything(args.seed, workers=True)
+
+    ### Hyperparameters of the model.
+    batch_size = cfg["training"]["batch_size"]
+    num_gpus = cfg["training"]["num_gpus"]
+    num_slots = cfg["model"]["num_slots"]
+    num_iterations = cfg["model"]["iters"]
+    resolution = (128, 128)
+    
+    #### Create datasets
+    test_dataset = data.get_dataset('val', cfg['data'])
+    test_dataloader = DataLoader(
+        test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
+    #### Create model
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    checkpoint = torch.load(args.ckpt)
+
+    model.load_state_dict(checkpoint['state_dict'])
+    model.eval()
+
+    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus,
+                         strategy="auto")
+    
+    trainer.validate(model, dataloaders=test_dataloader)
+
+if __name__ == "__main__":
+    main()
+
diff --git a/osrt/model.py b/osrt/model.py
index 71511e635ca2dc8f9b8695bbab3a9bc2314cfa16..4b04c5b0ee3fd8d523554e9bcdec6e1a7460a002 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -115,7 +115,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         x = self.encoder_pos(x)
         x = self.mlp(self.layer_norm(x))
         
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots)
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
         x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
         x = self.decoder_pos(x)
         x = self.decoder_cnn(x.movedim(-1, 1))
@@ -187,3 +187,19 @@
 
         return {'loss': loss_value, 'val_psnr': psnr.item()}
 
+    def test_step(self, batch, batch_idx):
+        """Perform a single eval step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        preds = self.one_step(input_image)
+        recon_combined, recons, masks, slots, _ = preds
+        #input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = self.criterion(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+        psnr = mse2psnr(loss_value)
+        self.log('test_loss', loss_value)
+        self.log('test_psnr', psnr)
+
+        return {'loss': loss_value, 'test_psnr': psnr.item()}
diff --git a/osrt/trainer.py b/osrt/trainer.py
index 01eb17dc2ee770a4ed623811bafa5479a24708cf..6afa1745b4ba75b4f78add713305a47843b4e0d9 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import numpy as np
 from tqdm import tqdm
 
-import osrt.utils.visualize as vis
+import osrt.utils.visualization_utils as vis
 from osrt.utils.common import mse2psnr, reduce_dict, gather_all, compute_adjusted_rand_index
 from osrt.utils import nerf
 from osrt.utils.common import get_rank, get_world_size
diff --git a/osrt/utils/visualize.py b/osrt/utils/visualization_utils.py
similarity index 95%
rename from osrt/utils/visualize.py
rename to osrt/utils/visualization_utils.py
index 677ad6db5179949b08c11b5a74c8d5be4aaa85d4..93f8c4e7013dacc84d2d9ba86d77d518a4e8c30a 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualization_utils.py
@@ -95,18 +95,18 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo
     recons = recons.squeeze(0)
     masks = masks.squeeze(0)
     image = image.permute(1,2,0).cpu().numpy()
-    recon_combined = recon_combined.cpu().detach().numpy()
+    recon_combined = recon_combined.permute(1,2,0).cpu().detach().numpy()
     recons = recons.cpu().detach().numpy()
     masks = masks.cpu().detach().numpy()
 
     # Extract data and put it on a plot
     ax[0].imshow(image)
     ax[0].set_title('Image')
-    ax[1].imshow(recon_combined)
+    ax[1].imshow((recon_combined * 255).astype(np.uint8))
     ax[1].set_title('Recon.')
     for i in range(6):
         picture = recons[i] * masks[i] + (1 - masks[i])
-        ax[i + 2].imshow(picture)
+        ax[i + 2].imshow(picture.transpose(1,2,0))
         ax[i + 2].set_title('Slot %s' % str(i + 1))
     for i in range(len(ax)):
         ax[i].grid(False)
diff --git a/render.py b/render.py
index 355553f656c26943763d2b7f4b1bac877266dc3a..b4edb0bfb3167e992f7010c3f813939b4b023575 100644
--- a/render.py
+++ b/render.py
@@ -9,7 +9,7 @@ from tqdm import tqdm
 
 from osrt.data import get_dataset
 from osrt.checkpoint import Checkpoint
-from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
+from osrt.utils.visualization_utils import visualize_2d_cluster, get_clustering_colors
 from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 1164186692f69964292e15c2bc6bd2ea9ed6a024..161aaf628a68e847d514e2e78bdf5867ceb2167d 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -6,8 +6,8 @@ model:
   model_type: sa
 training:
   num_workers: 2 
-  num_gpus: 8
-  batch_size: 32 
+  num_gpus: 1
+  batch_size: 64 
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
diff --git a/train_sa.py b/train_sa.py
index 25f5c124da5a746f89a7aaf736d3c2ce42825306..ad03192032f1e171721ee3f801dffa6d1400182d 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -7,7 +7,8 @@ import yaml
 
 from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
-from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.visualization_utils import visualize_slot_attention
+from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
@@ -70,10 +71,30 @@ def main():
 
     trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", 
                          default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None,
-                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "default", callbacks=[checkpoint_callback],
+                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback],
                          log_every_n_steps=100, max_steps=num_train_steps)
 
     trainer.fit(model, train_loader, val_loader)
+
+
+
+    #### Create datasets
+    vis_dataset = data.get_dataset('train', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+    
+    device = model.device
+
+    #### Visualize image
+    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
+    recon_combined, recons, masks, slots, _ = model(image)
+    loss = nn.MSELoss()
+    loss_value = loss(recon_combined, image)
+    psnr = mse2psnr(loss_value)
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
                 
 if __name__ == "__main__":
     main()
diff --git a/visualise.py b/visualize_sa.py
similarity index 63%
rename from visualise.py
rename to visualize_sa.py
index 19d311b6cd7703d72366b3447465518d659c3775..9153006d27592d4d0961fc60f246f29d60505ef6 100644
--- a/visualise.py
+++ b/visualize_sa.py
@@ -8,7 +8,7 @@ import yaml
 
 from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
-from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.visualization_utils import visualize_slot_attention
 from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
@@ -37,47 +37,33 @@ def main():
     ### Hyperparameters of the model.
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
-    base_learning_rate = 0.0004
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     resolution = (128, 128)
     
     #### Create datasets
-    
-    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
         vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
-    num_params = sum(p.numel() for p in model.parameters())
-
-    print('Number of parameters:')
-    print(f'Model slot attention: {num_params}')
-
-    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
+    checkpoint = torch.load(args.ckpt)
+    
+    model.load_state_dict(checkpoint['state_dict'])
 
-    ckpt = {
-        'network': model,
-        'optimizer': optimizer,
-        'global_step': 1639
-    }
-    #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
-    """ckpt = torch.load('~/ckpt.pth')
-    model = ckpt['network']"""
-    model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"])
+    model.eval()
 
+    #### Visualize image
     image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
     image = F.interpolate(image, size=128)
     image = image.to(device)
-    recon_combined, recons, masks, slots = model(image)
+    recon_combined, recons, masks, _, _ = model(image)
     loss = nn.MSELoss()
-    input_image = image.permute(0, 2, 3, 1)
-    loss_value = loss(recon_combined, input_image)
+    loss_value = loss(recon_combined, image)
     psnr = mse2psnr(loss_value)
-    print(f"MSE value : {loss_value} VS PSNR {psnr}")
-    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=1639, save_file=True)
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
                   
 if __name__ == "__main__":
     main()
\ No newline at end of file