diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index a4a4ebf9d869be9b73a51f470ce854dd0497738a..074ac76600a9597bb6f9ab1f00b208917479e969 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -173,8 +173,6 @@ class Clevr3dDataset(Dataset):
             target_pixels = all_pixels
             target_masks = all_masks
 
-        print(f"Final input_image : {input_images.shape} and type {type(input_images)}")
-        print(f"Final input_masks : {input_masks.shape} and type {type(input_masks)}")
 
         result = {
             'input_images':         input_images,         # [1, 3, h, w]
@@ -212,6 +210,7 @@ class Clevr2dDataset(Dataset):
         self.max_objects = max_objects
 
         self.max_num_entities = 11
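+        # Side length (in pixels) the centre-cropped images and masks are resized to.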
+        self.rescale = 128
 
         self.start_idx, self.end_idx = {'train': (0, 70000),
                                         'val': (70000, 75000),
@@ -241,9 +240,8 @@ class Clevr2dDataset(Dataset):
         img = img[..., :3].astype(np.float32) / 255
 
         input_image = crop_center(img, 192) 
-        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=128)
-        input_image = input_image.squeeze(0).permute(1, 2, 0)
-
+        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=self.rescale)
+        input_image = input_image.squeeze(0)
 
         mask_path = os.path.join(self.path, 'masks', f'masks_{scene_idx}_0.png')
         mask_idxs = imageio.imread(mask_path)
@@ -257,10 +255,12 @@ class Clevr2dDataset(Dataset):
         input_masks = crop_center(torch.tensor(masks), 192)
         input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
         input_masks = input_masks.squeeze(0).permute(1, 2, 0)
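+        # Flatten the masks to [h*w, max_num_entities] so the ARI metric can consume them directly.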
+        target_masks = input_masks.reshape(self.rescale * self.rescale, self.max_num_entities)
 
         result = {
-            'input_image':          input_image,         # [3, h, w]
+            'input_images':         input_image,         # [3, h, w]
             'input_masks':          input_masks,         # [h, w, self.max_num_entities]
+            'target_masks':         target_masks,        # [h*w, self.max_num_entities]
             'sceneid':              idx,                 # int
         }
 
diff --git a/osrt/model.py b/osrt/model.py
index b1ede9301008f67bfe361b9d842d71ec0222eb88..aac7276fe9af96eb47afa5dc5959bccf90583dda 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -4,7 +4,7 @@ from torch import nn
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.optim.lr_scheduler as lr_scheduler
+from torch.optim.lr_scheduler import LambdaLR
 
 import numpy as np
 
@@ -52,18 +52,17 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
     """
 
-    def __init__(self, resolution, num_slots, num_iterations, cfg):
+    def __init__(self, resolution, num_slots, cfg):
         """Builds the Slot Attention-based auto-encoder.
 
         Args:
         resolution: Tuple of integers specifying width and height of input image.
         num_slots: Number of slots in Slot Attention.
-        num_iterations: Number of iterations in Slot Attention.
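+        cfg: Config dict; model hyperparameters (iters, input_dim, slot_dim, hidden_dim, model_type) are read from cfg["model"].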
         """
         super().__init__()
         self.resolution = resolution
         self.num_slots = num_slots
-        self.num_iterations = num_iterations
+        self.num_iterations = cfg["model"]["iters"]
 
         self.criterion = nn.MSELoss()
 
@@ -79,17 +78,17 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         if model_type == 'sa':
             self.slot_attention = SlotAttention(
                 num_slots=self.num_slots,
-                input_dim=64,
-                slot_dim=64,
-                hidden_dim=128,
+                input_dim=cfg["model"]["input_dim"],
+                slot_dim=cfg["model"]["slot_dim"],
+                hidden_dim=cfg["model"]["hidden_dim"],
                 iters=self.num_iterations)
         elif model_type == 'tsa':
             # We set the same number of inside parameters
             self.slot_attention = TransformerSlotAttention(
                 num_slots=self.num_slots,
-                input_dim=64,
-                slot_dim=64,
-                hidden_dim=128,
+                input_dim=cfg["model"]["input_dim"],
+                slot_dim=cfg["model"]["slot_dim"],
+                hidden_dim=cfg["model"]["hidden_dim"],
                 depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
@@ -102,7 +101,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
             it_since_peak = it - self.peak_it
             return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
-        optimizer = optim.Adam(self.parameters(), lr=0)
+        optimizer = optim.Adam(self.parameters(), lr=1.0)  # base lr must be 1.0: LambdaLR multiplies it by lr_func, which returns the absolute learning rate
-        scheduler = optim.LambdaLR(optimizer, lr_lambda=lr_func)
+        scheduler = LambdaLR(optimizer, lr_lambda=lr_func)
         return {
             'optimizer': optimizer, 
             'lr_scheduler': {
@@ -124,57 +123,80 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         recons, masks = x.split((3, 1), dim = 2)
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
-
         
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, attn_shape) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
         input_image = torch.squeeze(batch.get('input_images'), dim=1) # Delete dim 1 if only one view
-        input_image = F.interpolate(input_image, size=128)
+        true_masks = torch.squeeze(batch.get('target_masks'), dim=1)
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
         loss_value = self.criterion(recon_combined, input_image)
-        del recons, masks, slots  # Unused.
+        del recons, slots  # Unused.
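+        # Flatten predicted slot masks to [B, num_slots, H*W] and score them against the
+        # ground-truth masks (background channel dropped) with the foreground adjusted Rand index.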
+        masks = masks.squeeze(2)
+        B, num_ob, H, W = masks.shape
+        masks = torch.reshape(masks, (B, num_ob, H*W))
+        fg_ari = compute_adjusted_rand_index(true_masks.transpose(1, 2)[:, 1:],
+                                                            masks).mean()
+        del masks
+        psnr = mse2psnr(loss_value)
 
-        fg_ari = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
-                                                            pred_seg.transpose(1, 2))
         self.log('train_mse', loss_value, on_epoch=True)
-        
+        self.log('train_fg_ari', fg_ari, on_epoch=True)
+        self.log('train_psnr', psnr, on_epoch=True)
+
+        #self.print(f"Training metrics, MSE: {loss_value}, FG-ARI: {fg_ari}, PSNR: {psnr.item()}")
         return {'loss': loss_value}
     
     def validation_step(self, batch, batch_idx):
         """Perform a single eval step."""
         input_image = torch.squeeze(batch.get('input_images'), dim=1)
-        input_image = F.interpolate(input_image, size=128)
+        true_masks = torch.squeeze(batch.get('target_masks'), dim=1) 
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
         loss_value = self.criterion(recon_combined, input_image)
-        del recons, masks, slots  # Unused.
+        del recons, slots  # Unused.
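+        # Same mask flattening and FG-ARI computation as in training_step.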
+        masks = masks.squeeze(2)
+        B, num_ob, H, W = masks.shape
+        masks = torch.reshape(masks, (B, num_ob, H*W))
+        fg_ari = compute_adjusted_rand_index(true_masks.transpose(1, 2)[:, 1:],
+                                                            masks).mean()
+        del masks
         psnr = mse2psnr(loss_value)
-        self.log('val_mse', loss_value)
-        self.log('s', psnr)
-        self.print(f"Validation metrics, MSE: {loss_value} PSNR: {psnr}")
+        self.log('val_mse', loss_value, on_epoch=True)
+        self.log('val_fg_ari', fg_ari, on_epoch=True)
+        self.log('val_psnr', psnr, on_epoch=True)
+
+        #self.print(f"Validation metrics, MSE: {loss_value}, FG-ARI: {fg_ari}, PSNR: {psnr.item()}")
         return {'loss': loss_value, 'val_psnr': psnr.item()}
 
     def test_step(self, batch, batch_idx):
         """Perform a single eval step."""
         input_image = torch.squeeze(batch.get('input_images'), dim=1)
-        input_image = F.interpolate(input_image, size=128)
+        true_masks = torch.squeeze(batch.get('target_masks'), dim=1) 
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
         loss_value = self.criterion(recon_combined, input_image)
-        del recons, masks, slots  # Unused.
+        del recons, slots  # Unused.
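+        # Same mask flattening and FG-ARI computation as in training_step.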
+        masks = masks.squeeze(2)
+        B, num_ob, H, W = masks.shape
+        masks = torch.reshape(masks, (B, num_ob, H*W))
+        fg_ari = compute_adjusted_rand_index(true_masks.transpose(1, 2)[:, 1:],
+                                                            masks).mean()
+        del masks
         psnr = mse2psnr(loss_value)
         self.log('test_loss', loss_value)
+        self.log('test_fg_ari', fg_ari)
         self.log('test_psnr', psnr)
 
+        #self.print(f"Test metrics, MSE: {loss_value}, FG-ARI: {fg_ari}, PSNR: {psnr.item()}")
         return {'loss': loss_value, 'test_psnr': psnr.item()}
     
\ No newline at end of file
diff --git a/quick_test.py b/quick_test.py
index a3aeaa5f7229cedeed6c670def8191a4ba76ca49..d63417171cf34291ea8a66fbee23435b2a972d00 100644
--- a/quick_test.py
+++ b/quick_test.py
@@ -11,6 +11,7 @@ train_dataset = data.get_dataset('train', cfg['data'])
 train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0,shuffle=True)
 
 for val in train_loader:
+    print(f"Shape masks {val['input_masks'].shape}")
     fig, axes = plt.subplots(2, 2)
-    axes[0][0].imshow(val['input_image'][0])
+    axes[0][0].imshow(val['input_images'][0].permute(1, 2, 0))
     axes[0][1].imshow(val['input_masks'][0][:, :, 0])
diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
index 38fecb2fc2259efd23d100888a35cecb3a23e50a..d53f13d186af8e2157e0ae1e0984ec58c8a9642b 100644
--- a/runs/clevr/slot_att/config.yaml
+++ b/runs/clevr/slot_att/config.yaml
@@ -4,18 +4,21 @@ model:
   num_slots: 10
   iters: 3
   model_type: sa
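+  # feature, slot and hidden dimensions passed to the Slot Attention module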
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
 training:
   num_workers: 2 
   num_gpus: 1
   batch_size: 8 
   max_it: 333000000
   warmup_it: 10000
+  lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
+  visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
-  print_every: 10
-  visualize_every: 5000
   backup_every: 25000
-  lr_warmup: 5000
 
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
index d6f40290c2f0dab960c4219647614be3f9f111a1..5c9592739a292f0557e5386040dc180203095b78 100644
--- a/runs/clevr/slot_att/config_tsa.yaml
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -4,16 +4,19 @@ model:
   num_slots: 10
   iters: 3
   model_type: tsa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
 training:
   num_workers: 2 
   num_gpus: 1
-  batch_size: 32
+  batch_size: 8
   max_it: 333000000
   warmup_it: 10000
+  lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  lr_warmup: 5000
-  print_every: 10
   visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
diff --git a/train_sa.py b/train_sa.py
index 86de00719beb8e2fb6b4db8600220c7aeba8acaf..c1e9daaac966da48fd4cd98ca77e7a3c06e316e3 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -46,24 +46,24 @@ def main():
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
+    num_workers = cfg["training"]["num_workers"]
     resolution = (128, 128)
     
     print(f"Number of CPU Cores : {os.cpu_count()}")
 
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
-    val_every = val_every // len(train_dataset)
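+    # Use most of the configured workers for training and keep 8 for the validation loader below;
+    # fall back to loading in the main process when few workers are configured.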
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"]-9,
+        train_dataset, batch_size=batch_size, num_workers=(num_workers - 9) if num_workers > 9 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
     
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=8,
+        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 9 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, cfg=cfg)
 
     if args.ckpt:
         checkpoint = torch.load(args.ckpt)
@@ -73,7 +73,7 @@ def main():
         monitor="val_psnr",
         mode="max",
         dirpath="./checkpoints",
-        filename="ckpt-" +  str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
+        filename="ckpt-" +  str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
         save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
         every_n_train_steps=cfg["training"]["checkpoint_every"]
     )
@@ -89,6 +89,7 @@ def main():
                          callbacks=[checkpoint_callback, early_stopping],
                          log_every_n_steps=100, 
                          val_check_interval=cfg["training"]["validate_every"],
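+                         # validate on a pure step interval; Lightning requires check_val_every_n_epoch=None
+                         # when val_check_interval may span more than one epoch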
+                         check_val_every_n_epoch=None,
                          max_steps=num_train_steps, 
                          enable_model_summary=True)