diff --git a/.gitignore b/.gitignore
index 5512be3fde32442485f31e5ccb6163325785758b..eb2575eea844d8142400315a9f4a7852aa02b097 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,5 @@ outputs
 logs/*
 logs
-data
 results
 *__pycache__
diff --git a/osrt/data/ycbv.py b/osrt/data/ycbv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9865ce9e38ef8268c426d269c6bfdce2a04ec9e
--- /dev/null
+++ b/osrt/data/ycbv.py
@@ -0,0 +1,371 @@
+import numpy as np
+import imageio
+import yaml
+from torch.utils.data import Dataset
+
+import os
+from os import listdir
+from os.path import isfile, join
+from .core import downsample, crop_center
+import torch
+import torch.nn.functional as F
+
+def get_extrinsic(camera_pos, rays=None, track_point=None, fourxfour=True):
+    """ Returns extrinsic matrix mapping world to camera coordinates.
+    Args:
+        camera_pos (np array [3]): Camera position.
+        track_point (np array [3]): Point on which the camera is fixated.
+        rays (np array [h, w, 3]): Rays emanating from the camera. Used to determine track_point
+            if it's not given.
+        fourxfour (bool): If true, a 4x4 matrix for homogeneous 3D coordinates is returned.
+            Otherwise, a 3x4 matrix is returned.
+    Returns:
+        extrinsic camera matrix (np array [4, 4] or [3, 4])
+    """
+    if track_point is None:
+        h, w, _ = rays.shape
+        if h % 2 == 0:
+            center_rays = rays[h//2 - 1:h//2 + 1]
+        else:
+            center_rays = rays[h//2:h//2+1]
+
+        if w % 2 == 0:
+            center_rays = rays[:, w//2 - 1:w//2 + 1]
+        else:
+            center_rays = rays[:, w//2:w//2+1]
+
+        camera_z = center_rays.mean((0, 1))
+    else:
+        camera_z = track_point - camera_pos
+
+    camera_z = camera_z / np.linalg.norm(camera_z, axis=-1, keepdims=True)
+
+    # We assume that (a) the z-axis is vertical, and that
+    # (b) the camera's horizontal axis, the x-axis, is orthogonal to the vertical, i.e.,
+    # the camera is in a level position.
+    vertical = np.array((0., 0., 1.))
+
+    camera_x = np.cross(camera_z, vertical)
+    camera_x = camera_x / np.linalg.norm(camera_x, axis=-1, keepdims=True)
+    camera_y = np.cross(camera_z, camera_x)
+
+    camera_matrix = np.stack((camera_x, camera_y, camera_z), -2)
+    translation = -np.einsum('...ij,...j->...i', camera_matrix, camera_pos)
+    camera_matrix = np.concatenate((camera_matrix, np.expand_dims(translation, -1)), -1)
+
+    if fourxfour:
+        filler = np.array([[0., 0., 0., 1.]])
+        camera_matrix = np.concatenate((camera_matrix, filler), 0)
+    return camera_matrix
+
+
+def transform_points(points, transform, translate=True):
+    """ Apply a linear transform to a np array of points.
+    Args:
+        points (np array [..., 3]): Points to transform.
+        transform (np array [3, 4] or [4, 4]): Linear map.
+        translate (bool): If false, do not apply the translation component of the transform.
+    Returns:
+        transformed points (np array [..., 3])
+    """
+    # Append ones or zeros to get homogeneous coordinates
+    if translate:
+        constant_term = np.ones_like(points[..., :1])
+    else:
+        constant_term = np.zeros_like(points[..., :1])
+    points = np.concatenate((points, constant_term), axis=-1)
+
+    points = np.einsum('nm,...m->...n', transform, points)
+    return points[..., :3]
+
+def get_camera_rays(c_pos, c_rot, width=640, height=480, focal_length=0.035, sensor_width=0.032,
+                    vertical=None):
+    if vertical is None:
+        vertical = np.array((0., 0., 1.))
+
+    # c_rot is used directly as the camera's (unit) viewing direction
+    c_dir = c_rot
+
+    img_plane_center = c_pos + c_dir * focal_length
+
+    # The horizontal axis of the camera sensor is horizontal (z=0) and orthogonal to the view axis
+    img_plane_horizontal = np.cross(c_dir, vertical)
+    img_plane_horizontal = img_plane_horizontal / np.linalg.norm(img_plane_horizontal)
+
+    # The vertical axis is orthogonal to both the view axis and the horizontal axis
+    img_plane_vertical = np.cross(c_dir, img_plane_horizontal)
+    img_plane_vertical = img_plane_vertical / np.linalg.norm(img_plane_vertical)
+
+    # Double check that everything is orthogonal
+    def is_small(x, atol=1e-7):
+        return abs(x) < atol
+
+    assert(is_small(np.dot(img_plane_vertical, img_plane_horizontal)))
+    assert(is_small(np.dot(img_plane_vertical, c_dir)))
+    assert(is_small(np.dot(c_dir, img_plane_horizontal)))
+
+    # Sensor height is implied by sensor width and aspect ratio
+    sensor_height = (sensor_width / width) * height
+
+    # Compute pixel boundaries
+    horizontal_offsets = np.linspace(-1, 1, width+1) * sensor_width / 2
+    vertical_offsets = np.linspace(-1, 1, height+1) * sensor_height / 2
+
+    # Compute pixel centers
+    horizontal_offsets = (horizontal_offsets[:-1] + horizontal_offsets[1:]) / 2
+    vertical_offsets = (vertical_offsets[:-1] + vertical_offsets[1:]) / 2
+
+    horizontal_offsets = np.repeat(np.reshape(horizontal_offsets, (1, width)), height, 0)
+    vertical_offsets = np.repeat(np.reshape(vertical_offsets, (height, 1)), width, 1)
+
+    horizontal_offsets = (np.reshape(horizontal_offsets, (height, width, 1)) *
+                          np.reshape(img_plane_horizontal, (1, 1, 3)))
+    vertical_offsets = (np.reshape(vertical_offsets, (height, width, 1)) *
+                        np.reshape(img_plane_vertical, (1, 1, 3)))
+
+    image_plane = horizontal_offsets + vertical_offsets
+
+    image_plane = image_plane + np.reshape(img_plane_center, (1, 1, 3))
+    c_pos_exp = np.reshape(c_pos, (1, 1, 3))
+    rays = image_plane - c_pos_exp
+    ray_norms = np.linalg.norm(rays, axis=2, keepdims=True)
+    rays = rays / ray_norms
+    return rays.astype(np.float32)
+
+def extract_images_path(global_path, mode, images, type_im="rgb"):
+    scenes = [f for f in listdir(global_path)]
+    for scene in scenes:
+        path = global_path + scene + f"/{type_im}/"
+        temp = np.array([join(path, f) for f in listdir(path) if isfile(join(path, f))])
+        temp.sort()
+        cut = int(len(temp)*0.7)
+        if mode == "train":
+            temp = temp[:cut]
+        else:
+            temp = temp[cut:]
+        images = np.concatenate((images, temp))
+    return images
+
+
+class YCBVideo3D(Dataset):
+    def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
+                 max_len=None, full_scale=False, shapenet=False, downsample=None):
+        """ Loads our adapted version of the YCB-Video dataset.
+
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            points_per_item (int): Number of target points per scene.
+            max_len (int): Limit on the number of entries in the dataset.
+            canonical_view (bool): Return data in canonical camera coordinates (like in SRT), as opposed
+                to world coordinates.
+            full_scale (bool): Return all available target points, instead of sampling.
+            downsample (int): Downsample height and width of input image by a factor of 2**downsample
+        """
+        self.path = path
+        self.mode = mode
+        self.points_per_item = points_per_item
+        self.max_len = max_len
+        self.canonical = canonical_view
+        self.full_scale = full_scale
+        self.shapenet = shapenet
+        self.downsample = downsample
+
+        self.max_num_entities = 21  # max number of objects in a scene
+        self.num_views = 3  # TODO : set this number for each scene
+
+        self.start_idx, self.end_idx = {'train': (0, 70000),
+                                        'val': (70000, 75000),
+                                        'test': (85000, 100000)}[mode]
+
+        self.metadata = np.load(os.path.join(path, 'metadata.npz'))
+        self.metadata = {k: v for k, v in self.metadata.items()}
+
+        self.idxs = np.arange(self.start_idx, self.end_idx)
+
+        dataset_name = 'YCB-Video'
+        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
+        print(self.idxs)
+
+        self.render_kwargs = {
+            'min_dist': 0.035,
+            'max_dist': 35.}
+
+    def __len__(self):
+        if self.max_len is not None:
+            return self.max_len
+        return len(self.idxs) * self.num_views
+
+    def __getitem__(self, idx):
+        scene_idx = idx % len(self.idxs)
+        view_idx = idx // len(self.idxs)
+
+        scene_idx = self.idxs[scene_idx]
+
+        imgs = [np.asarray(imageio.imread(
+            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
+            for v in range(self.num_views)]
+
+        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
+
+        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
+                     for v in range(self.num_views)]
+        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
+        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
+
+        all_rays = []
+        # TODO : find a way to get the camera poses
+        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
+        all_camera_rot = self.metadata['camera_rot'][:self.num_views].astype(np.float32)
+        for i in range(self.num_views):
+            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i])  # TODO : adapt function
+            all_rays.append(cur_rays)
+        all_rays = np.stack(all_rays, 0).astype(np.float32)
+
+        input_camera_pos = all_camera_pos[view_idx]
+
+        if self.canonical:
+            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
+            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point)  # TODO : adapt function
+            canonical_extrinsic = canonical_extrinsic.astype(np.float32)
+            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False)  # TODO : adapt function
+            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
+            input_camera_pos = all_camera_pos[view_idx]
+
+        input_rays = all_rays[view_idx]
+        input_rays = downsample(input_rays, num_steps=self.downsample)
+        input_rays = np.expand_dims(input_rays, 0)
+
+        input_masks = masks[view_idx]
+        input_masks = downsample(input_masks, num_steps=self.downsample)
+        input_masks = np.expand_dims(input_masks, 0)
+
+        input_camera_pos = np.expand_dims(input_camera_pos, 0)
+
+        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
+        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
+        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
+        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
+        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
+
+        num_points = all_rays.shape[0]
+
+        if not self.full_scale:
+            # If we have fewer points than we want, sample with replacement
+            replace = num_points < self.points_per_item
+            sampled_idxs = np.random.choice(np.arange(num_points),
+                                            size=(self.points_per_item,),
+                                            replace=replace)
+
+            target_rays = all_rays[sampled_idxs]
+            target_camera_pos = all_camera_pos[sampled_idxs]
+            target_pixels = all_pixels[sampled_idxs]
+            target_masks = all_masks[sampled_idxs]
+        else:
+            target_rays = all_rays
+            target_camera_pos = all_camera_pos
+            target_pixels = all_pixels
+            target_masks = all_masks
+
+        result = {
+            'input_images': input_images,  # [1, 3, h, w]
+            'input_camera_pos': input_camera_pos,  # [1, 3]
+            'input_rays': input_rays,  # [1, h, w, 3]
+            'input_masks': input_masks,  # [1, h, w, self.max_num_entities]
+            'target_pixels': target_pixels,  # [p, 3]
+            'target_camera_pos': target_camera_pos,  # [p, 3]
+            'target_rays': target_rays,  # [p, 3]
+            'target_masks': target_masks,  # [p, self.max_num_entities]
+            'sceneid': idx,  # int
+        }
+
+        if self.canonical:
+            result['transform'] = canonical_extrinsic  # [3, 4] (optional)
+
+        return result
+
+class YCBVideo2D(Dataset):
+    def __init__(self, path, mode, max_objects=6):
+        """ Loads individual YCB-Video frames and their segmentation masks (2D data only).
+
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            max_objects (int): Load only scenes with at most this many objects.
+        """
+        self.path = path
+        print(f"Get path {path}")
+        self.mode = mode
+        self.max_objects = max_objects
+
+        self.max_num_entities = 22
+        self.rescale = 128
+
+        """self.metadata = np.load(os.path.join(self.path, 'metadata.npz'))
+        self.metadata = {k: v for k, v in self.metadata.items()}
+
+        num_objs = (self.metadata['shape'][self.start_idx:self.end_idx] > 0).sum(1)
+
+        self.idxs = np.arange(self.start_idx, self.end_idx)[num_objs <= max_objects]"""
+
+        self.images = np.empty(shape=(0,), dtype=np.str_)
+        self.masks = np.empty(shape=(0,), dtype=np.str_)
+        if mode == "test":
+            self.path += "/test/"
+            scenes = [f for f in listdir(self.path)]
+            for scene in scenes:
+                path = self.path + scene
+                temp = np.array([join(path, f) for f in listdir(path) if isfile(join(path, f))])
+                self.images = np.concatenate((self.images, temp))
+            # TODO : masks are not collected for the test split yet
+        else:
+            path_real = self.path + "/train_real/"
+            path_synth = self.path + "/train_synth/"
+            self.images = extract_images_path(path_real, mode, self.images)
+            self.images = extract_images_path(path_synth, mode, self.images)
+            self.masks = extract_images_path(path_real, mode, self.masks, "masks")
+            self.masks = extract_images_path(path_synth, mode, self.masks, "masks")
+        dataset_name = 'YCB'
+
+        print(f"Load dataset {dataset_name} in mode {self.mode}")
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx, noisy=True):
+        """scene_idx = idx % len(self.idxs)
+        scene_idx = self.idxs[scene_idx]"""
+
+        img_path = self.images[idx]
+        img = np.asarray(imageio.imread(img_path))
+        img = img[..., :3].astype(np.float32) / 255
+
+        input_image = crop_center(img, 440)
+        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=self.rescale)
+        input_image = input_image.squeeze(0)
+
+        mask_path = self.masks[idx]
+        mask_idxs = imageio.imread(mask_path)
+
+        masks = np.zeros((480, 640, self.max_num_entities), dtype=np.uint8)
+
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        input_masks = crop_center(torch.tensor(masks), 440)
+        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=self.rescale)
+        input_masks = input_masks.squeeze(0).permute(1, 2, 0)
+        target_masks = np.reshape(input_masks, (self.rescale*self.rescale, self.max_num_entities))
+
+        result = {
+            'input_images': input_image,  # [3, h, w]
+            'input_masks': input_masks,  # [h, w, self.max_num_entities]
+            'target_masks': target_masks,  # [h*w, self.max_num_entities]
+            'sceneid': idx,  # int
+        }
+
+        return result
+
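Usage note (not part of the patch): a minimal sketch of how the two new datasets could be exercised, assuming a dataset root of data/ycbv laid out as the loaders expect (metadata.npz, images/ and masks/ for YCBVideo3D; train_real/, train_synth/ and test/ for YCBVideo2D). The root path, batch size, and worker count are placeholders.

    from torch.utils.data import DataLoader
    from osrt.data.ycbv import YCBVideo3D, YCBVideo2D

    # Scene-level dataset: one input view plus sampled target rays/pixels per item.
    train_3d = YCBVideo3D('data/ycbv', mode='train', points_per_item=2048)
    loader = DataLoader(train_3d, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch['input_images'].shape)   # [8, 1, 3, h, w]
    print(batch['target_pixels'].shape)  # [8, 2048, 3]

    # Image-level dataset: center-cropped RGB frames with per-pixel instance masks.
    train_2d = YCBVideo2D('data/ycbv', mode='train')
    sample = train_2d[0]
    print(sample['input_images'].shape)  # [3, 128, 128]
    print(sample['target_masks'].shape)  # [128*128, 22]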