diff --git a/osrt/data/__init__.py b/osrt/data/__init__.py
index 428145ed318953710fabcc7f2fc677d64ea20356..96d64f2bae157df08fdac791a9e16baf5eedc92a 100644
--- a/osrt/data/__init__.py
+++ b/osrt/data/__init__.py
@@ -1,5 +1,5 @@
 from osrt.data.core import get_dataset, worker_init_fn
 from osrt.data.nmr import NMRDataset
 from osrt.data.multishapenet import MultishapenetDataset 
-from osrt.data.obsurf import Clevr3dDataset
+from osrt.data.obsurf import Clevr3dDataset, Clevr2dDataset
 
diff --git a/osrt/data/core.py b/osrt/data/core.py
index 81327f04901a8384fcdba2be9b4da7b09e8c1c0e..3b8886ebe8c87662e814f8908a085cd560fae565 100644
--- a/osrt/data/core.py
+++ b/osrt/data/core.py
@@ -43,6 +43,8 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=False, max_len=max_len, full_scale=full_scale,
                                       canonical_view=canonical_view, **kwargs)
+    elif dataset_type == 'clevr2d':
+        dataset = data.Clevr2dDataset(dataset_folder, mode, **kwargs)
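+        # Selected from YAML via `data: dataset: clevr2d`
+        # (see runs/clevr/slot_att/config.yaml added later in this change).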
     elif dataset_type == 'obsurf_msn':
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=True, max_len=max_len, full_scale=full_scale,
diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index 4a93013db5236f76161fdaf4dc6147609ba67294..a4a4ebf9d869be9b73a51f470ce854dd0497738a 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -1,11 +1,15 @@
 import numpy as np
 import imageio
-import yaml
+
 from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+import torch
+from PIL import Image
 
 import os
 
 from osrt.utils.nerf import get_camera_rays, get_extrinsic, transform_points
+import torch.nn.functional as F
 
 
 def downsample(x, num_steps=1):
@@ -14,6 +18,24 @@ def downsample(x, num_steps=1):
     stride = 2**num_steps
     return x[stride//2::stride, stride//2::stride]
 
+def crop_center(image, crop_size=192):
+    height, width = image.shape[:2]
+
+    center_x = width // 2
+    center_y = height // 2
+
+    crop_size_half = crop_size // 2
+
+    # Calculate the top-left corner coordinates of the crop
+    crop_x1 = center_x - crop_size_half
+    crop_y1 = center_y - crop_size_half
+
+    # Calculate the bottom-right corner coordinates of the crop
+    crop_x2 = center_x + crop_size_half
+    crop_y2 = center_y + crop_size_half
+
+    # Crop the image
+    return image[crop_y1:crop_y2, crop_x1:crop_x2]
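+
+# Shape sketch (illustrative, not part of the original code): on a 240x320 CLEVR
+# frame, crop_center keeps the central window, which Clevr2dDataset below then
+# resizes to the 128x128 resolution expected by the encoder, e.g.
+#   frame = np.zeros((240, 320, 3), dtype=np.float32)   # dummy frame
+#   crop_center(frame, 192).shape                        # (192, 192, 3)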
 
 class Clevr3dDataset(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
@@ -93,6 +115,7 @@ class Clevr3dDataset(Dataset):
         masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
         np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
 
+
         input_image = downsample(imgs[view_idx], num_steps=self.downsample)
         input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
 
@@ -150,6 +173,9 @@ class Clevr3dDataset(Dataset):
             target_pixels = all_pixels
             target_masks = all_masks
 
+        print(f"Final input_image : {input_images.shape} and type {type(input_images)}")
+        print(f"Final input_masks : {input_masks.shape} and type {type(input_masks)}")
+
         result = {
             'input_images':         input_images,         # [1, 3, h, w]
             'input_camera_pos':     input_camera_pos,     # [1, 3]
@@ -168,3 +194,75 @@ class Clevr3dDataset(Dataset):
         return result
 
 
+class Clevr2dDataset(Dataset):
+    def __init__(self, path, mode, max_objects=6):
+        """ Loads the dataset used in the ObSuRF paper.
+
+        They may be downloaded at: https://stelzner.github.io/obsurf
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            full_scale (bool): Return all available target points, instead of sampling.
+            max_objects (int): Load only scenes with at most this many objects.
+            shapenet (bool): Load ObSuRF's MultiShapeNet dataset, instead of CLEVR3D.
+            downsample (int): Downsample height and width of input image by a factor of 2**downsample
+        """
+        # the 2D samples reuse the images and masks shipped with the CLEVR3D data
+        self.path = path.replace("clevr2d", "clevr3d")
+        self.mode = mode
+        self.max_objects = max_objects
+
+        self.max_num_entities = 11
+
+        self.start_idx, self.end_idx = {'train': (0, 70000),
+                                        'val': (70000, 75000),
+                                        'test': (85000, 100000)}[mode]
+
+        self.metadata = np.load(os.path.join(self.path, 'metadata.npz'))
+        self.metadata = {k: v for k, v in self.metadata.items()}
+
+        num_objs = (self.metadata['shape'][self.start_idx:self.end_idx] > 0).sum(1)
+
+        self.idxs = np.arange(self.start_idx, self.end_idx)[num_objs <= max_objects]
+
+        dataset_name = 'CLEVR'
+        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
+        print(self.idxs)
+
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, idx, noisy=True):
+        scene_idx = idx % len(self.idxs)
+        scene_idx = self.idxs[scene_idx]
+
+        img_path = os.path.join(self.path, 'images', f'img_{scene_idx}_0.png')
+        img = np.asarray(imageio.imread(img_path))
+        img = img[..., :3].astype(np.float32) / 255
+
+        # centre-crop to 192x192, resize to 128x128, then back to channels-last
+        input_image = crop_center(img, 192)
+        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=128)
+        input_image = input_image.squeeze(0).permute(1, 2, 0)  # [128, 128, 3]
+
+        mask_path = os.path.join(self.path, 'masks', f'masks_{scene_idx}_0.png')
+        mask_idxs = imageio.imread(mask_path)
+
+        masks = np.zeros((240, 320, self.max_num_entities), dtype=np.uint8)
+
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        metadata = {k: v[scene_idx] for (k, v) in self.metadata.items()}
+
+        # F.interpolate expects a floating-point tensor, so cast the one-hot masks before resizing
+        input_masks = crop_center(torch.tensor(masks, dtype=torch.float32), 192)
+        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
+        input_masks = input_masks.squeeze(0).permute(1, 2, 0)  # [128, 128, self.max_num_entities]
+
+        result = {
+            'input_image':          input_image,         # [h, w, 3]
+            'input_masks':          input_masks,         # [h, w, self.max_num_entities]
+            'sceneid':              idx,                 # int
+        }
+
+        return result
+
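+# Minimal usage sketch (paths are assumptions, mirroring quick_test.py added in
+# this change):
+#   ds = Clevr2dDataset('data/clevr2d', 'val')
+#   sample = ds[0]
+#   sample['input_image'].shape   # torch.Size([128, 128, 3])
+#   sample['input_masks'].shape   # torch.Size([128, 128, 11])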
diff --git a/osrt/data/ycb-video.py b/osrt/data/ycb-video.py
index aa1a1b88e8f6f4ccde51e213969cfaec3018efaf..615a13f132464ee50c9ca2f324439e79c3e52320 100644
--- a/osrt/data/ycb-video.py
+++ b/osrt/data/ycb-video.py
@@ -134,7 +134,7 @@ def downsample(x, num_steps=1):
     return x[stride//2::stride, stride//2::stride]
 
 
-class YCBVideo(Dataset):
+class YCBVideo3D(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
                  max_len=None, full_scale=False, shapenet=False, downsample=None):
         """ Loads the YCB-Video dataset that we have adapted.
@@ -274,4 +274,99 @@ class YCBVideo(Dataset):
 
         return result
 
+class YCBVideo2D(Dataset):
+    def __init__(self, path, mode, downsample=None):
+        """ Loads the YCB-Video dataset that we have adapted.
+
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            downsample (int): Downsample height and width of input image by a factor of 2**downsample
+        """
+        self.path = path
+        self.mode = mode
+        self.downsample = downsample
+
+        self.max_num_entities = 21 # max number of objects in a scene 
+
+        # TODO : set the right number here
+        # TODO : build self.idxs (used by __len__ below); it is not defined yet
+
+        dataset_name = 'YCB-Video'
+        print(f'Initialized {dataset_name} {mode} set')
+
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, idx):
+        # TODO: this is still a skeleton copied from YCBVideo3D; scene_idx, view_idx,
+        # self.num_views, self.metadata and self.canonical are not defined yet.
+
+        imgs = [np.asarray(imageio.imread(
+            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
+            for v in range(self.num_views)]
+
+        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
+
+        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
+                    for v in range(self.num_views)]
+        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
+        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
+
+        all_rays = []
+        # TODO : find a way to get the camera poses
+        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
+        all_camera_rot = self.metadata['camera_rot'][:self.num_views].astype(np.float32)
+        for i in range(self.num_views):
+            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
+            all_rays.append(cur_rays)
+        all_rays = np.stack(all_rays, 0).astype(np.float32)
+
+        input_camera_pos = all_camera_pos[view_idx]
+
+        if self.canonical:
+            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
+            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
+            canonical_extrinsic = canonical_extrinsic.astype(np.float32) 
+            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
+            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
+            input_camera_pos = all_camera_pos[view_idx]
+
+        input_rays = all_rays[view_idx]
+        input_rays = downsample(input_rays, num_steps=self.downsample)
+        input_rays = np.expand_dims(input_rays, 0)
+
+        input_masks = masks[view_idx]
+        input_masks = downsample(input_masks, num_steps=self.downsample)
+        input_masks = np.expand_dims(input_masks, 0)
+
+        input_camera_pos = np.expand_dims(input_camera_pos, 0)
+
+        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
+        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
+        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
+        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
+        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
+
+        target_camera_pos = all_camera_pos
+        target_pixels = all_pixels
+        target_masks = all_masks
+
+        result = {
+            'input_images':         input_images,         # [1, 3, h, w]
+            'input_camera_pos':     input_camera_pos,     # [1, 3]
+            'input_rays':           input_rays,           # [1, h, w, 3]
+            'input_masks':          input_masks,          # [1, h, w, self.max_num_entities]
+            'target_pixels':        target_pixels,        # [p, 3]
+            'target_camera_pos':    target_camera_pos,    # [p, 3]
+            'target_masks':         target_masks,         # [p, self.max_num_entities]
+            'sceneid':              idx,                  # int
+        }
+
+        if self.canonical:
+            result['transform'] = canonical_extrinsic     # [3, 4] (optional)
+
+        return result
+
 
diff --git a/osrt/model.py b/osrt/model.py
index 24842c9d926bda12ab66cfb52bbe0d99c2fcd0e7..b1ede9301008f67bfe361b9d842d71ec0222eb88 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -4,6 +4,7 @@ from torch import nn
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
 
 import numpy as np
 
@@ -11,7 +12,7 @@ from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
 from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
 import osrt.layers as layers
-from osrt.utils.common import mse2psnr
+from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
 
 import lightning as pl
 
@@ -69,6 +70,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         self.encoder = Encoder()
         self.decoder = Decoder()
 
+        self.peak_it = cfg["training"]["warmup_it"]
+        self.peak_lr = 1e-4
+        self.decay_rate = 0.16
+        self.decay_it = cfg["training"]["decay_it"]
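+        # Note: peak_lr and decay_rate are fixed here; cfg["training"]["decay_rate"]
+        # from the YAML configs is not read by this scheduler.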
+
         model_type = cfg['model']['model_type']
         if model_type == 'sa':
             self.slot_attention = SlotAttention(
@@ -87,24 +93,23 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
                 depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        x = self.encoder(image)
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
-        
-        x = self.decoder(x)
-        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
-
-        x = x.unflatten(0, (len(image), len(x) // len(image)))
-
-        recons, masks = x.split((3, 1), dim = 2)
-        masks = masks.softmax(dim = 1)
-        recon_combined = (recons * masks).sum(dim = 1)
-
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
+        return self.one_step(image)
     
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
-        return optimizer
+        def lr_func(it):
+            if it < self.peak_it:  # Warmup period
+                return self.peak_lr * (it / self.peak_it)
+            it_since_peak = it - self.peak_it
+            return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
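+        # Resulting schedule (sketch, with the values set in __init__ above:
+        # peak_lr=1e-4, decay_rate=0.16, warmup_it/decay_it from the config):
+        #   it = 0                  -> lr = 0
+        #   it = peak_it            -> lr = 1e-4   (end of warmup)
+        #   it = peak_it + decay_it -> lr = 1e-4 * 0.16 = 1.6e-5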
+        # base lr of 1.0 so that LambdaLR applies the absolute learning rate returned by lr_func
+        optimizer = optim.Adam(self.parameters(), lr=1.0)
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
+        return {
+            'optimizer': optimizer, 
+            'lr_scheduler': {
+                'scheduler': scheduler, 
+                'interval': 'step' # Update the scheduler at each step
+            }
+        }
     
     def one_step(self, image):
         x = self.encoder(image)
@@ -126,18 +131,19 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
-        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = torch.squeeze(batch.get('input_images'), dim=1) # Delete dim 1 if only one view
         input_image = F.interpolate(input_image, size=128)
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
 
+        # TODO: true_seg and pred_seg are not defined yet, so the foreground-ARI
+        # computation is left commented out to avoid a NameError at train time.
+        # fg_ari = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+        #                                      pred_seg.transpose(1, 2))
         self.log('train_mse', loss_value, on_epoch=True)
-
+        
         return {'loss': loss_value}
     
     def validation_step(self, batch, batch_idx):
@@ -148,13 +154,12 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)
         self.log('val_mse', loss_value)
-        self.log('val_psnr', psnr)
-
+        self.log('val_psnr', psnr)
+        self.print(f"Validation metrics, MSE: {loss_value} PSNR: {psnr}")
         return {'loss': loss_value, 'val_psnr': psnr.item()}
 
     def test_step(self, batch, batch_idx):
@@ -165,7 +170,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)
diff --git a/outputs/visualisation_10.png b/outputs/visualisation_10.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c2093c157c26767aa27566db0e917f9b4d9184
Binary files /dev/null and b/outputs/visualisation_10.png differ
diff --git a/outputs/visualisation_2.png b/outputs/visualisation_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..97fa47ad3757362be19f32b08b647a3e2da265dd
Binary files /dev/null and b/outputs/visualisation_2.png differ
diff --git a/outputs/visualisation_3.png b/outputs/visualisation_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5cf616373e92e744c64439f228bf0ea79ce6a0c
Binary files /dev/null and b/outputs/visualisation_3.png differ
diff --git a/outputs/visualisation_4.png b/outputs/visualisation_4.png
new file mode 100644
index 0000000000000000000000000000000000000000..94be3e3e76c2bc7184ba032ddc690506402f08b8
Binary files /dev/null and b/outputs/visualisation_4.png differ
diff --git a/outputs/visualisation_5.png b/outputs/visualisation_5.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3f565c9df06d067e0c4ba19cfa19854e92c6636
Binary files /dev/null and b/outputs/visualisation_5.png differ
diff --git a/outputs/visualisation_6.png b/outputs/visualisation_6.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa4283d62f1cefcc89d59aa2253ab3363c7b9f96
Binary files /dev/null and b/outputs/visualisation_6.png differ
diff --git a/outputs/visualisation_7.png b/outputs/visualisation_7.png
new file mode 100644
index 0000000000000000000000000000000000000000..8947de651e61d0ad5559c3cd120febd0c7a495d2
Binary files /dev/null and b/outputs/visualisation_7.png differ
diff --git a/quick_test.py b/quick_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3aeaa5f7229cedeed6c670def8191a4ba76ca49
--- /dev/null
+++ b/quick_test.py
@@ -0,0 +1,19 @@
+import yaml
+from osrt import data
+from torch.utils.data import DataLoader
+import matplotlib.pyplot as plt
+
+with open("runs/clevr/slot_att/config.yaml", 'r') as f:
+    cfg = yaml.load(f, Loader=yaml.CLoader)
+
+
+train_dataset = data.get_dataset('train', cfg['data'])
+train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0, shuffle=True)
+
+for val in train_loader:
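+    # each batch stacks two samples (batch_size=2): val['input_image'] is
+    # [2, 128, 128, 3] and val['input_masks'] is [2, 128, 128, 11]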
+    fig, axes = plt.subplots(2, 2)
+    axes[0][0].imshow(val['input_image'][0])
+    axes[0][1].imshow(val['input_masks'][0][:, :, 0])
+    axes[1][0].imshow(val['input_image'][1])
+    axes[1][1].imshow(val['input_masks'][1][:, :, 0])
+    plt.show()
\ No newline at end of file
diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38fecb2fc2259efd23d100888a35cecb3a23e50a
--- /dev/null
+++ b/runs/clevr/slot_att/config.yaml
@@ -0,0 +1,21 @@
+data:
+  dataset: clevr2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: sa
+training:
+  num_workers: 2 
+  num_gpus: 1
+  batch_size: 8 
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+  validate_every: 5000
+  checkpoint_every: 1000
+  print_every: 10
+  visualize_every: 5000
+  backup_every: 25000
+  lr_warmup: 5000
+
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6f40290c2f0dab960c4219647614be3f9f111a1
--- /dev/null
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -0,0 +1,21 @@
+data:
+  dataset: clevr2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
+training:
+  num_workers: 2 
+  num_gpus: 1
+  batch_size: 32
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+  lr_warmup: 5000
+  print_every: 10
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/runs/clevr3d/osrt/config.yaml b/runs/clevr3d/osrt/config.yaml
deleted file mode 100644
index f11b301b7615bc2ce72702e67d8d7ab619c4f17a..0000000000000000000000000000000000000000
--- a/runs/clevr3d/osrt/config.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-
-data:
-  dataset: clevr3d
-  num_points: 2000 
-  kwargs:
-    downsample: 1
-model:
-  encoder: osrt
-  encoder_kwargs:
-    pos_start_octave: -5
-    num_slots: 6
-  decoder: slot_mixer
-  decoder_kwargs:
-    pos_start_octave: -5
-training:
-  num_workers: 2 
-  batch_size: 64 
-  model_selection_metric: psnr
-  model_selection_mode: maximize
-  print_every: 10
-  visualize_every: 5000
-  validate_every: 5000
-  checkpoint_every: 1000
-  backup_every: 25000
-  max_it: 333000000
-  decay_it: 4000000
-  lr_warmup: 5000
-
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 82c0f8e78d858da06e976697291a7189012fba51..4dceb22aa530ae0ffb6dfe5859ced635bf73c29c 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -12,4 +12,10 @@ training:
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
+  validate_every: 5000
+  checkpoint_every: 1000
+  print_every: 10
+  visualize_every: 5000
+  backup_every: 25000
+  lr_warmup: 5000
 
diff --git a/runs/clevr3d/slot_att/config_tsa.yaml b/runs/clevr3d/slot_att/config_tsa.yaml
index f03fc97062dd583a0884019890f8fbd0778280b1..6255ab06539a5a7f63af67fc6ebd2cafaaf922e6 100644
--- a/runs/clevr3d/slot_att/config_tsa.yaml
+++ b/runs/clevr3d/slot_att/config_tsa.yaml
@@ -12,4 +12,10 @@ training:
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
+  lr_warmup: 5000
+  print_every: 10
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
 
diff --git a/train_sa.py b/train_sa.py
index 79379b06054071619dc1fddbb3df246f51744b8c..86de00719beb8e2fb6b4db8600220c7aeba8acaf 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -1,5 +1,3 @@
-import datetime
-import time
 import torch
 import torch.nn as nn
 import argparse
@@ -12,12 +10,18 @@ from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
-from tqdm import tqdm
+import os
 
 import lightning as pl
 from lightning.pytorch.loggers.wandb import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
 
+import warnings
+from lightning.pytorch.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+
+# Ignore Lightning warnings that are likely false positives, cf. https://lightning.ai/docs/pytorch/stable/advanced/speed.html
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
 
 def main():
     # Arguments
@@ -42,21 +46,21 @@ def main():
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
-    warmup_steps = cfg["training"]["warmup_it"]
-    decay_rate = cfg["training"]["decay_rate"]
-    decay_steps = cfg["training"]["decay_it"]
     resolution = (128, 128)
     
+    print(f"Number of CPU Cores : {os.cpu_count()}")
+
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
-        shuffle=True, worker_init_fn=data.worker_init_fn)
+        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
     
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=1,
-        shuffle=True, worker_init_fn=data.worker_init_fn)
+        val_dataset, batch_size=batch_size, num_workers=8,
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
@@ -65,26 +69,35 @@ def main():
         checkpoint = torch.load(args.ckpt)
         model.load_state_dict(checkpoint['state_dict'])
 
-
     checkpoint_callback = ModelCheckpoint(
-        save_top_k=10,
         monitor="val_psnr",
         mode="max",
-        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
-        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}.pth",
+        dirpath="./checkpoints",
+        filename="ckpt-" +  str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
+        save_weights_only=True, # save only the model weights (no optimizer or lr-scheduler state)
+        every_n_train_steps=cfg["training"]["checkpoint_every"]
     )
 
-    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", 
-                         default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None,
-                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback],
-                         log_every_n_steps=100, max_steps=num_train_steps)
+    early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
+
+    trainer = pl.Trainer(accelerator="gpu", 
+                         devices=num_gpus, 
+                         profiler="simple", 
+                         default_root_dir="./logs", 
+                         logger=WandbLogger(project="slot-att") if args.wandb else None,
+                         strategy="ddp" if num_gpus > 1 else "auto", 
+                         callbacks=[checkpoint_callback, early_stopping],
+                         log_every_n_steps=100, 
+                         val_check_interval=cfg["training"]["validate_every"],
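+                         # an integer val_check_interval is counted in training batches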
+                         max_steps=num_train_steps, 
+                         enable_model_summary=True)
 
     trainer.fit(model, train_loader, val_loader)
 
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        vis_dataset, batch_size=1, num_workers=1,
         shuffle=True, worker_init_fn=data.worker_init_fn)
     
     device = model.device
@@ -100,21 +113,4 @@ def main():
     visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
                 
 if __name__ == "__main__":
-    main()
-
-
-#print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-
-"""
-
-if not epoch % cfg["training"]["checkpoint_every"]:
-            # Save the checkpoint of the model.
-            ckpt['global_step'] = global_step
-            ckpt['model_state_dict'] = model.state_dict()
-            torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
-            print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
-
-        # We visualize some test data
-        if not epoch % cfg["training"]["visualize_every"]:
-            
-"""
\ No newline at end of file
+    main()
\ No newline at end of file
diff --git a/visualize_sa.py b/visualize_sa.py
index 2e4031d46c69a38c0a82a1451ac90f6f7b52be18..b9e78d372f5cdacbefad2190120a5a99db0c4945 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -26,7 +26,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
-    parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images')
+    parser.add_argument('--output', type=str, default="./outputs/", help='Folder in which to save images')
     parser.add_argument('--step', type=int, default=".", help='Step of the model')
 
     args = parser.parse_args()
@@ -44,7 +44,7 @@ def main():
     resolution = (128, 128)
     
     #### Create datasets
-    vis_dataset = data.get_dataset('train', cfg['data'])
+    vis_dataset = data.get_dataset('val', cfg['data'])
     vis_loader = DataLoader(
         vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)