diff --git a/Nonevisualisation_0.png b/Nonevisualisation_0.png
deleted file mode 100644
index 59df42cecd44a6e04b3cc578d596a74d54a5cc10..0000000000000000000000000000000000000000
Binary files a/Nonevisualisation_0.png and /dev/null differ
diff --git a/evaluate_sa.py b/evaluate_sa.py
index e5e06eb7ec4990daa29ba21aead49de35b00f270..bdd4e629011f9513dab13dadd02daa04e06ece59 100644
--- a/evaluate_sa.py
+++ b/evaluate_sa.py
@@ -19,7 +19,6 @@ def main():
         description='Train a 3D scene representation model.'
     )
     parser.add_argument('config', type=str, help="Where to save the checkpoints.")
-    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
 
diff --git a/osrt/data/__init__.py b/osrt/data/__init__.py
index 96d64f2bae157df08fdac791a9e16baf5eedc92a..84a842d13a91f31f58c3e6e93526afa784f28f4b 100644
--- a/osrt/data/__init__.py
+++ b/osrt/data/__init__.py
@@ -2,4 +2,5 @@ from osrt.data.core import get_dataset, worker_init_fn
 from osrt.data.nmr import NMRDataset
 from osrt.data.multishapenet import MultishapenetDataset 
 from osrt.data.obsurf import Clevr3dDataset, Clevr2dDataset
+from osrt.data.ycbv import YCBVideo2D
 
diff --git a/osrt/data/core.py b/osrt/data/core.py
index 3b8886ebe8c87662e814f8908a085cd560fae565..5a8bb82fca7244cf1490205e516dbd653a7412cd 100644
--- a/osrt/data/core.py
+++ b/osrt/data/core.py
@@ -24,6 +24,7 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
 
     if 'kwargs' in cfg:
         kwargs = cfg['kwargs']
+        print(f"Dataset kwargs: {kwargs}")
     else:
         kwargs = dict()
 
@@ -45,6 +46,9 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
                                       canonical_view=canonical_view, **kwargs)
     elif dataset_type == 'clevr2d':
         dataset = data.Clevr2dDataset(dataset_folder, mode, **kwargs)
+    elif dataset_type == 'ycb2d':
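+        # YCBVideo2D is the adapted YCB-Video dataset (see osrt/data/ycbv.py)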
+        dataset = data.YCBVideo2D(dataset_folder, mode, **kwargs)
     elif dataset_type == 'obsurf_msn':
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=True, max_len=max_len, full_scale=full_scale,
@@ -62,3 +65,34 @@ def worker_init_fn(worker_id):
     base_seed = int.from_bytes(random_data, byteorder="big")
     np.random.seed(base_seed + worker_id)
 
+
+def downsample(x, num_steps=1):
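+    """Downsample the two leading (spatial) dimensions by a factor of 2**num_steps."""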
+    if num_steps is None or num_steps < 1:
+        return x
+    stride = 2**num_steps
+    return x[stride//2::stride, stride//2::stride]
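+# Example (assuming an [h, w, c] image input): downsample(img, num_steps=1)
+# samples every other pixel starting at offset 1, so [240, 320, 3] -> [120, 160, 3].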
+
+def crop_center(image, crop_size=192):
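+    """Return a centered crop_size x crop_size crop of an [h, w, ...] image."""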
+    height, width = image.shape[:2]
+
+    center_x = width // 2
+    center_y = height // 2
+
+    crop_size_half = crop_size // 2
+
+    # Calculate the top-left corner coordinates of the crop
+    crop_x1 = center_x - crop_size_half
+    crop_y1 = center_y - crop_size_half
+
+    # Calculate the bottom-right corner coordinates of the crop
+    crop_x2 = center_x + crop_size_half
+    crop_y2 = center_y + crop_size_half
+
+    # Crop the image
+    return image[crop_y1:crop_y2, crop_x1:crop_x2]
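+# Example (assuming a [480, 640, 3] image): crop_center(img, 192) keeps
+# rows 144:336 and columns 224:416, returning a [192, 192, 3] array.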
\ No newline at end of file
diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index 074ac76600a9597bb6f9ab1f00b208917479e969..bb43841cb8ad27013d7a6e5ae4deafb7c4ca5183 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -10,32 +10,7 @@ import os
 
 from osrt.utils.nerf import get_camera_rays, get_extrinsic, transform_points
 import torch.nn.functional as F
-
-
-def downsample(x, num_steps=1):
-    if num_steps is None or num_steps < 1:
-        return x
-    stride = 2**num_steps
-    return x[stride//2::stride, stride//2::stride]
-
-def crop_center(image, crop_size=192):
-    height, width = image.shape[:2]
-
-    center_x = width // 2
-    center_y = height // 2
-
-    crop_size_half = crop_size // 2
-
-    # Calculate the top-left corner coordinates of the crop
-    crop_x1 = center_x - crop_size_half
-    crop_y1 = center_y - crop_size_half
-
-    # Calculate the bottom-right corner coordinates of the crop
-    crop_x2 = center_x + crop_size_half
-    crop_y2 = center_y + crop_size_half
-
-    # Crop the image
-    return image[crop_y1:crop_y2, crop_x1:crop_x2]
+from .core import downsample, crop_center
 
 class Clevr3dDataset(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
@@ -250,8 +225,6 @@ class Clevr2dDataset(Dataset):
 
         np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
 
-        metadata = {k: v[scene_idx] for (k, v) in self.metadata.items()}
-
         input_masks = crop_center(torch.tensor(masks), 192)
         input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
         input_masks = input_masks.squeeze(0).permute(1, 2, 0)
diff --git a/osrt/data/ycb-video.py b/osrt/data/ycb-video.py
deleted file mode 100644
index 615a13f132464ee50c9ca2f324439e79c3e52320..0000000000000000000000000000000000000000
--- a/osrt/data/ycb-video.py
+++ /dev/null
@@ -1,372 +0,0 @@
-import numpy as np
-import imageio
-import yaml
-from torch.utils.data import Dataset
-
-import os
-
-def get_extrinsic(camera_pos, rays=None, track_point=None, fourxfour=True):
-    """ Returns extrinsic matrix mapping world to camera coordinates.
-    Args:
-        camera_pos (np array [3]): Camera position.
-        track_point (np array [3]): Point on which the camera is fixated.
-        rays (np array [h, w, 3]): Rays eminating from the camera. Used to determine track_point
-            if it's not given.
-        fourxfour (bool): If true, a 4x4 matrix for homogeneous 3D coordinates is returned.
-            Otherwise, a 3x4 matrix is returned.
-    Returns:
-        extrinsic camera matrix (np array [4, 4] or [3, 4])
-    """
-    if track_point is None:
-        h, w, _ = rays.shape
-        if h % 2 == 0:
-            center_rays = rays[h//2 - 1:h//2 + 1]
-        else:
-            center_rays = rays[h//2:h//2+1]
-
-        if w % 2 == 0:
-            center_rays = rays[:, w//2 - 1:w//2 + 1]
-        else:
-            center_rays = rays[:, w//2:w//2+1]
-
-        camera_z = center_rays.mean((0, 1))
-    else:
-        camera_z = track_point - camera_pos
-
-    camera_z = camera_z / np.linalg.norm(camera_z, axis=-1, keepdims=True)
-
-    # We assume that (a) the z-axis is vertical, and that
-    # (b) the camera's horizontal, the x-axis, is orthogonal to the vertical, i.e.,
-    # the camera is in a level position.
-    vertical = np.array((0., 0., 1.))
-
-    camera_x = np.cross(camera_z, vertical)
-    camera_x = camera_x / np.linalg.norm(camera_x, axis=-1, keepdims=True)
-    camera_y = np.cross(camera_z, camera_x)
-
-    camera_matrix = np.stack((camera_x, camera_y, camera_z), -2)
-    translation = -np.einsum('...ij,...j->...i', camera_matrix, camera_pos)
-    camera_matrix = np.concatenate((camera_matrix, np.expand_dims(translation, -1)), -1)
-
-    if fourxfour:
-        filler = np.array([[0., 0., 0., 1.]])
-        camera_matrix = np.concatenate((camera_matrix, filler), 0)
-    return camera_matrix
-
-
-def transform_points(points, transform, translate=True):
-    """ Apply linear transform to a np array of points.
-    Args:
-        points (np array [..., 3]): Points to transform.
-        transform (np array [3, 4] or [4, 4]): Linear map.
-        translate (bool): If false, do not apply translation component of transform.
-    Returns:
-        transformed points (np array [..., 3])
-    """
-    # Append ones or zeros to get homogenous coordinates
-    if translate:
-        constant_term = np.ones_like(points[..., :1])
-    else:
-        constant_term = np.zeros_like(points[..., :1])
-    points = np.concatenate((points, constant_term), axis=-1)
-
-    points = np.einsum('nm,...m->...n', transform, points)
-    return points[..., :3]
-
-def get_camera_rays(c_pos, c_rot, width=640, height=480, focal_length=0.035, sensor_width=0.032,
-                    vertical=None):
-    if vertical is None:
-        vertical = np.array((0., 0., 1.))
-
-    c_dir = c_rot
-
-    img_plane_center = c_pos + c_dir * focal_length
-
-    # The horizontal axis of the camera sensor is horizontal (z=0) and orthogonal to the view axis
-    img_plane_horizontal = np.cross(c_dir, vertical)
-    img_plane_horizontal = img_plane_horizontal / np.linalg.norm(img_plane_horizontal)
-
-    # The vertical axis is orthogonal to both the view axis and the horizontal axis
-    img_plane_vertical = np.cross(c_dir, img_plane_horizontal)
-    img_plane_vertical = img_plane_vertical / np.linalg.norm(img_plane_vertical)
-
-    # Double check that everything is orthogonal
-    def is_small(x, atol=1e-7):
-        return abs(x) < atol
-
-    assert(is_small(np.dot(img_plane_vertical, img_plane_horizontal)))
-    assert(is_small(np.dot(img_plane_vertical, c_dir)))
-    assert(is_small(np.dot(c_dir, img_plane_horizontal)))
-
-    # Sensor height is implied by sensor width and aspect ratio
-    sensor_height = (sensor_width / width) * height
-
-    # Compute pixel boundaries
-    horizontal_offsets = np.linspace(-1, 1, width+1) * sensor_width / 2
-    vertical_offsets = np.linspace(-1, 1, height+1) * sensor_height / 2
-
-    # Compute pixel centers
-    horizontal_offsets = (horizontal_offsets[:-1] + horizontal_offsets[1:]) / 2
-    vertical_offsets = (vertical_offsets[:-1] + vertical_offsets[1:]) / 2
-
-    horizontal_offsets = np.repeat(np.reshape(horizontal_offsets, (1, width)), height, 0)
-    vertical_offsets = np.repeat(np.reshape(vertical_offsets, (height, 1)), width, 1)
-
-
-    horizontal_offsets = (np.reshape(horizontal_offsets, (height, width, 1)) *
-                          np.reshape(img_plane_horizontal, (1, 1, 3)))
-    vertical_offsets = (np.reshape(vertical_offsets, (height, width, 1)) *
-                        np.reshape(img_plane_vertical, (1, 1, 3)))
-
-    image_plane = horizontal_offsets + vertical_offsets
-
-    image_plane = image_plane + np.reshape(img_plane_center, (1, 1, 3))
-    c_pos_exp = np.reshape(c_pos, (1, 1, 3))
-    rays = image_plane - c_pos_exp
-    ray_norms = np.linalg.norm(rays, axis=2, keepdims=True)
-    rays = rays / ray_norms
-    return rays.astype(np.float32)
-
-def downsample(x, num_steps=1):
-    if num_steps is None or num_steps < 1:
-        return x
-    stride = 2**num_steps
-    return x[stride//2::stride, stride//2::stride]
-
-
-class YCBVideo3D(Dataset):
-    def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
-                 max_len=None, full_scale=False, shapenet=False, downsample=None):
-        """ Loads the YCB-Video dataset that we have adapted.
-
-        Args:
-            path (str): Path to dataset.
-            mode (str): 'train', 'val', or 'test'.
-            points_per_item (int): Number of target points per scene.
-            max_len (int): Limit to the number of entries in the dataset.
-            canonical_view (bool): Return data in canonical camera coordinates (like in SRT), as opposed
-                to world coordinates.
-            full_scale (bool): Return all available target points, instead of sampling.
-            downsample (int): Downsample height and width of input image by a factor of 2**downsample
-        """
-        self.path = path
-        self.mode = mode
-        self.points_per_item = points_per_item
-        self.max_len = max_len
-        self.canonical = canonical_view
-        self.full_scale = full_scale
-        self.shapenet = shapenet
-        self.downsample = downsample
-
-        self.max_num_entities = 21 # max number of objects in a scene 
-        self.num_views = 3 # TODO : set this number for each scene 
-
-        self.start_idx, self.end_idx = {'train': (0, 70000),
-                                        'val': (70000, 75000),
-                                        'test': (85000, 100000)}[mode]
-
-        self.metadata = np.load(os.path.join(path, 'metadata.npz'))
-        self.metadata = {k: v for k, v in self.metadata.items()}
-
-        self.idxs = np.arange(self.start_idx, self.end_idx)
-
-        dataset_name = 'YCB-Video'
-        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
-        print(self.idxs)
-
-        self.render_kwargs = {
-            'min_dist': 0.035,
-            'max_dist': 35.}
-
-    def __len__(self):
-        if self.max_len is not None:
-            return self.max_len
-        return len(self.idxs) * self.num_views
-
-    def __getitem__(self, idx):
-        scene_idx = idx % len(self.idxs)
-        view_idx = idx // len(self.idxs)
-
-        scene_idx = self.idxs[scene_idx]
-
-        imgs = [np.asarray(imageio.imread(
-            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
-            for v in range(self.num_views)]
-
-        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
-
-        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
-                    for v in range(self.num_views)]
-        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
-        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
-
-        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
-        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
-
-        all_rays = []
-        # TODO : find a way to get the camera poses
-        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
-        all_camera_rot= self.metadata['camera_rot'][:self.num_views].astype(np.float32)
-        for i in range(self.num_views):
-            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
-            all_rays.append(cur_rays)
-        all_rays = np.stack(all_rays, 0).astype(np.float32)
-
-        input_camera_pos = all_camera_pos[view_idx]
-
-        if self.canonical:
-            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
-            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
-            canonical_extrinsic = canonical_extrinsic.astype(np.float32) 
-            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
-            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
-            input_camera_pos = all_camera_pos[view_idx]
-
-        input_rays = all_rays[view_idx]
-        input_rays = downsample(input_rays, num_steps=self.downsample)
-        input_rays = np.expand_dims(input_rays, 0)
-
-        input_masks = masks[view_idx]
-        input_masks = downsample(input_masks, num_steps=self.downsample)
-        input_masks = np.expand_dims(input_masks, 0)
-
-        input_camera_pos = np.expand_dims(input_camera_pos, 0)
-
-        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
-        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
-        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
-        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
-        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
-
-        num_points = all_rays.shape[0]
-
-        if not self.full_scale:
-            # If we have fewer points than we want, sample with replacement
-            replace = num_points < self.points_per_item
-            sampled_idxs = np.random.choice(np.arange(num_points),
-                                            size=(self.points_per_item,),
-                                            replace=replace)
-
-            target_rays = all_rays[sampled_idxs]
-            target_camera_pos = all_camera_pos[sampled_idxs]
-            target_pixels = all_pixels[sampled_idxs]
-            target_masks = all_masks[sampled_idxs]
-        else:
-            target_rays = all_rays
-            target_camera_pos = all_camera_pos
-            target_pixels = all_pixels
-            target_masks = all_masks
-
-        result = {
-            'input_images':         input_images,         # [1, 3, h, w]
-            'input_camera_pos':     input_camera_pos,     # [1, 3]
-            'input_rays':           input_rays,           # [1, h, w, 3]
-            'input_masks':          input_masks,          # [1, h, w, self.max_num_entities]
-            'target_pixels':        target_pixels,        # [p, 3]
-            'target_camera_pos':    target_camera_pos,    # [p, 3]
-            'target_rays':          target_rays,          # [p, 3]
-            'target_masks':         target_masks,         # [p, self.max_num_entities]
-            'sceneid':              idx,                  # int
-        }
-
-        if self.canonical:
-            result['transform'] = canonical_extrinsic     # [3, 4] (optional)
-
-        return result
-
-class YCBVideo2D(Dataset):
-    def __init__(self, path, mode, downsample=None):
-        """ Loads the YCB-Video dataset that we have adapted.
-
-        Args:
-            path (str): Path to dataset.
-            mode (str): 'train', 'val', or 'test'.
-            downsample (int): Downsample height and width of input image by a factor of 2**downsample
-        """
-        self.path = path
-        self.mode = mode
-        self.downsample = downsample
-
-        self.max_num_entities = 21 # max number of objects in a scene 
-
-        # TODO : set the right number here
-
-        dataset_name = 'YCB-Video'
-        print(f'Initialized {dataset_name} {mode} set')
-
-
-    def __len__(self):
-        return len(self.idxs)
-
-    def __getitem__(self, idx):
-
-        imgs = [np.asarray(imageio.imread(
-            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
-            for v in range(self.num_views)]
-
-        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
-
-        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
-                    for v in range(self.num_views)]
-        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
-        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
-
-        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
-        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
-
-        all_rays = []
-        # TODO : find a way to get the camera poses
-        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
-        all_camera_rot= self.metadata['camera_rot'][:self.num_views].astype(np.float32)
-        for i in range(self.num_views):
-            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
-            all_rays.append(cur_rays)
-        all_rays = np.stack(all_rays, 0).astype(np.float32)
-
-        input_camera_pos = all_camera_pos[view_idx]
-
-        if self.canonical:
-            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
-            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
-            canonical_extrinsic = canonical_extrinsic.astype(np.float32) 
-            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
-            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
-            input_camera_pos = all_camera_pos[view_idx]
-
-        input_rays = all_rays[view_idx]
-        input_rays = downsample(input_rays, num_steps=self.downsample)
-        input_rays = np.expand_dims(input_rays, 0)
-
-        input_masks = masks[view_idx]
-        input_masks = downsample(input_masks, num_steps=self.downsample)
-        input_masks = np.expand_dims(input_masks, 0)
-
-        input_camera_pos = np.expand_dims(input_camera_pos, 0)
-
-        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
-        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
-        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
-        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
-        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
-
-        target_camera_pos = all_camera_pos
-        target_pixels = all_pixels
-        target_masks = all_masks
-
-        result = {
-            'input_images':         input_images,         # [1, 3, h, w]
-            'input_camera_pos':     input_camera_pos,     # [1, 3]
-            'input_rays':           input_rays,           # [1, h, w, 3]
-            'input_masks':          input_masks,          # [1, h, w, self.max_num_entities]
-            'target_pixels':        target_pixels,        # [p, 3]
-            'target_camera_pos':    target_camera_pos,    # [p, 3]
-            'target_masks':         target_masks,         # [p, self.max_num_entities]
-            'sceneid':              idx,                  # int
-        }
-
-        if self.canonical:
-            result['transform'] = canonical_extrinsic     # [3, 4] (optional)
-
-        return result
-
-
diff --git a/osrt/utils/visualization_utils.py b/osrt/utils/visualization_utils.py
index 6e36fafe5039d8abf7a76931366649be8a6c5599..f86b365a1f4d113094089bc56b8006c7ebc45370 100644
--- a/osrt/utils/visualization_utils.py
+++ b/osrt/utils/visualization_utils.py
@@ -114,4 +114,6 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo
     if not save_file:
         plt.show()
     else:
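+        # Default to the current directory so the filename is not prefixed with 'None'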
+        folder_save = folder_save if folder_save is not None else "./"
         plt.savefig(f'{folder_save}visualisation_{step}.png', bbox_inches='tight')
diff --git a/quick_test.py b/quick_test.py
index d63417171cf34291ea8a66fbee23435b2a972d00..7d8cf5c16b74bffe57194df2bd48448e1ec389e2 100644
--- a/quick_test.py
+++ b/quick_test.py
@@ -3,18 +3,20 @@ from osrt import data
 from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt
 
-with open("runs/clevr/slot_att/config.yaml", 'r') as f:
+with open("runs/ycb/slot_att/config.yaml", 'r') as f:
     cfg = yaml.load(f, Loader=yaml.CLoader)
 
-
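+# NOTE: machine-specific absolute path; adjust to your local copy of the dataset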
+cfg['data']['path'] = "/home/achapin/Documents/Datasets/ycbv_small"
 train_dataset = data.get_dataset('train', cfg['data'])
 train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0,shuffle=True)
 
+
 for val in train_loader:
     print(f"Shape masks {val['input_masks'].shape}")
     fig, axes = plt.subplots(2, 2)
-    axes[0][0].imshow(val['input_image'][0])
+    axes[0][0].imshow(val['input_images'][0].permute(1, 2, 0))
     axes[0][1].imshow(val['input_masks'][0][:, :, 0])
-    axes[1][0].imshow(val['input_image'][1])
+    axes[1][0].imshow(val['input_images'][1].permute(1, 2, 0))
     axes[1][1].imshow(val['input_masks'][1][:, :, 0])
     plt.show()
\ No newline at end of file
diff --git a/runs/ycb/slot_att/config.yaml b/runs/ycb/slot_att/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ebc5132f030cfbcb7720e5226fdd97a272414880
--- /dev/null
+++ b/runs/ycb/slot_att/config.yaml
@@ -0,0 +1,24 @@
+data:
+  dataset: ycb2d
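+  # no data.path set here: quick_test.py injects it at runtime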
+model:
+  num_slots: 10
+  iters: 3
+  model_type: sa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
+training:
+  num_workers: 2 
+  num_gpus: 1
+  batch_size: 8 
+  max_it: 333000000
+  warmup_it: 10000
+  lr_warmup: 5000
+  decay_rate: 0.5
+  decay_it: 100000
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/runs/ycb/slot_att/config_tsa.yaml b/runs/ycb/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0e2a05a7a6a25fc626b602c57bb6983049b7b7d8
--- /dev/null
+++ b/runs/ycb/slot_att/config_tsa.yaml
@@ -0,0 +1,24 @@
+data:
+  dataset: ycb2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
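+  # identical to config.yaml apart from model_type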
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
+training:
+  num_workers: 2 
+  num_gpus: 1
+  batch_size: 8
+  max_it: 333000000
+  warmup_it: 10000
+  lr_warmup: 5000
+  decay_rate: 0.5
+  decay_it: 100000
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/train_sa.py b/train_sa.py
index c1e9daaac966da48fd4cd98ca77e7a3c06e316e3..5963cf580f6422d0809ccbe2f465bc14e7be7161 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -44,7 +44,6 @@ def main():
     batch_size = cfg["training"]["batch_size"]
     num_gpus = cfg["training"]["num_gpus"]
     num_slots = cfg["model"]["num_slots"]
-    num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
     num_workers = cfg["training"]["num_workers"]
     resolution = (128, 128)
@@ -84,7 +83,8 @@
                          devices=num_gpus, 
                          profiler="simple", 
                          default_root_dir="./logs", 
-                         logger=WandbLogger(project="slot-att") if args.wandb else None,
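+                         # offline mode keeps W&B logs local; upload later with `wandb sync`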
+                         logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
                          strategy="ddp" if num_gpus > 1 else "auto", 
                          callbacks=[checkpoint_callback, early_stopping],
                          log_every_n_steps=100,