import numpy as np
import imageio
import yaml
from torch.utils.data import Dataset

import os
from os import listdir
from os.path import isfile, join
from .core import downsample, crop_center
import torch
import torch.nn.functional as F

def get_extrinsic(camera_pos, rays=None, track_point=None, fourxfour=True):
    """ Returns extrinsic matrix mapping world to camera coordinates.
    Args:
        camera_pos (np array [3]): Camera position.
        rays (np array [h, w, 3]): Rays emanating from the camera. Used to determine
            the track_point if it is not given.
        track_point (np array [3]): Point on which the camera is fixated.
        fourxfour (bool): If true, a 4x4 matrix for homogeneous 3D coordinates is returned.
            Otherwise, a 3x4 matrix is returned.
    Returns:
        extrinsic camera matrix (np array [4, 4] or [3, 4])
    """
    if track_point is None:
        h, w, _ = rays.shape
        if h % 2 == 0:
            center_rays = rays[h//2 - 1:h//2 + 1]
        else:
            center_rays = rays[h//2:h//2+1]

        if w % 2 == 0:
            center_rays = center_rays[:, w//2 - 1:w//2 + 1]
        else:
            center_rays = center_rays[:, w//2:w//2+1]

        camera_z = center_rays.mean((0, 1))
    else:
        camera_z = track_point - camera_pos

    camera_z = camera_z / np.linalg.norm(camera_z, axis=-1, keepdims=True)

    # We assume that (a) the world z-axis is vertical, and (b) the camera's
    # x-axis is horizontal, i.e. orthogonal to the vertical, so the camera is
    # in a level position (no roll).
    vertical = np.array((0., 0., 1.))

    camera_x = np.cross(camera_z, vertical)
    camera_x = camera_x / np.linalg.norm(camera_x, axis=-1, keepdims=True)
    camera_y = np.cross(camera_z, camera_x)

    camera_matrix = np.stack((camera_x, camera_y, camera_z), -2)
    translation = -np.einsum('...ij,...j->...i', camera_matrix, camera_pos)
    camera_matrix = np.concatenate((camera_matrix, np.expand_dims(translation, -1)), -1)

    if fourxfour:
        filler = np.array([[0., 0., 0., 1.]])
        camera_matrix = np.concatenate((camera_matrix, filler), 0)
    return camera_matrix
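
# Quick sketch (illustrative values, not from the dataset): the extrinsic for a
# camera at (2, 2, 2) fixated on the origin. The top-left 3x3 block stacks the
# camera x/y/z axes as rows, so it is orthonormal; the last column is the
# translation -R @ camera_pos.
#
#   E = get_extrinsic(np.array([2., 2., 2.]), track_point=np.zeros(3))  # [4, 4]
#   R = E[:3, :3]
#   assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)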


def transform_points(points, transform, translate=True):
    """ Apply linear transform to a np array of points.
    Args:
        points (np array [..., 3]): Points to transform.
        transform (np array [3, 4] or [4, 4]): Linear map.
        translate (bool): If false, do not apply translation component of transform.
    Returns:
        transformed points (np array [..., 3])
    """
    # Append ones or zeros to get homogeneous coordinates
    if translate:
        constant_term = np.ones_like(points[..., :1])
    else:
        constant_term = np.zeros_like(points[..., :1])
    points = np.concatenate((points, constant_term), axis=-1)

    points = np.einsum('nm,...m->...n', transform, points)
    return points[..., :3]
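
# Minimal sketch of how the two helpers compose (camera position and sample point
# are arbitrary illustrative values): with the extrinsic above, the camera itself
# maps to the origin of camera coordinates and the tracked point lands on the
# positive z-axis, in front of the camera.
#
#   cam_pos = np.array([2., 2., 2.])
#   E = get_extrinsic(cam_pos, track_point=np.zeros(3))
#   assert np.allclose(transform_points(cam_pos, E), 0., atol=1e-6)
#   print(transform_points(np.zeros(3), E))   # ~[0, 0, 3.46]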

def get_camera_rays(c_pos, c_rot, width=640, height=480, focal_length=0.035, sensor_width=0.032,
                    vertical=None):
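    """ Returns a unit ray direction through each pixel of a pinhole camera.
    Args:
        c_pos (np array [3]): Camera position in world coordinates.
        c_rot (np array [3]): Camera viewing direction (assumed to be a unit vector).
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        focal_length (float): Focal length of the camera.
        sensor_width (float): Width of the sensor; the sensor height follows from the
            image aspect ratio.
        vertical (np array [3]): World up direction. Defaults to the z-axis.
    Returns:
        rays (np array [height, width, 3]): Normalized ray directions.
    """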
    if vertical is None:
        vertical = np.array((0., 0., 1.))

    c_dir = c_rot

    img_plane_center = c_pos + c_dir * focal_length

    # The horizontal axis of the camera sensor is horizontal (z=0) and orthogonal to the view axis
    img_plane_horizontal = np.cross(c_dir, vertical)
    img_plane_horizontal = img_plane_horizontal / np.linalg.norm(img_plane_horizontal)

    # The vertical axis is orthogonal to both the view axis and the horizontal axis
    img_plane_vertical = np.cross(c_dir, img_plane_horizontal)
    img_plane_vertical = img_plane_vertical / np.linalg.norm(img_plane_vertical)

    # Double check that everything is orthogonal
    def is_small(x, atol=1e-7):
        return abs(x) < atol

    assert is_small(np.dot(img_plane_vertical, img_plane_horizontal))
    assert is_small(np.dot(img_plane_vertical, c_dir))
    assert is_small(np.dot(c_dir, img_plane_horizontal))

    # Sensor height is implied by sensor width and aspect ratio
    sensor_height = (sensor_width / width) * height

    # Compute pixel boundaries
    horizontal_offsets = np.linspace(-1, 1, width+1) * sensor_width / 2
    vertical_offsets = np.linspace(-1, 1, height+1) * sensor_height / 2

    # Compute pixel centers
    horizontal_offsets = (horizontal_offsets[:-1] + horizontal_offsets[1:]) / 2
    vertical_offsets = (vertical_offsets[:-1] + vertical_offsets[1:]) / 2

    horizontal_offsets = np.repeat(np.reshape(horizontal_offsets, (1, width)), height, 0)
    vertical_offsets = np.repeat(np.reshape(vertical_offsets, (height, 1)), width, 1)


    horizontal_offsets = (np.reshape(horizontal_offsets, (height, width, 1)) *
                          np.reshape(img_plane_horizontal, (1, 1, 3)))
    vertical_offsets = (np.reshape(vertical_offsets, (height, width, 1)) *
                        np.reshape(img_plane_vertical, (1, 1, 3)))

    image_plane = horizontal_offsets + vertical_offsets

    image_plane = image_plane + np.reshape(img_plane_center, (1, 1, 3))
    c_pos_exp = np.reshape(c_pos, (1, 1, 3))
    rays = image_plane - c_pos_exp
    ray_norms = np.linalg.norm(rays, axis=2, keepdims=True)
    rays = rays / ray_norms
    return rays.astype(np.float32)
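
# Sketch (illustrative camera pose, not read from the metadata): rays for a camera
# at (1, 0, 1) looking towards the origin. Every returned direction has unit norm.
#
#   c_pos = np.array([1., 0., 1.])
#   c_dir = -c_pos / np.linalg.norm(c_pos)
#   rays = get_camera_rays(c_pos, c_dir, width=320, height=240)   # [240, 320, 3]
#   assert np.allclose(np.linalg.norm(rays, axis=-1), 1.0, atol=1e-5)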

def extract_images_path(global_path, mode, images, type_im="rgb"):
    """ Collects paths to all `type_im` files of every scene under `global_path` and
    appends them to `images`. Each scene's sorted file list is split 70/30, keeping
    the first 70% in 'train' mode and the remainder otherwise. """
    scenes = listdir(global_path)
    for scene in scenes:
        path = global_path + scene + f"/{type_im}/"
        temp = np.array(list([join(path, f) for f in listdir(path) if isfile(join(path, f))]))
        temp.sort()
        cut = int(len(temp)*0.7)
        if mode == "train":
            temp = temp[:cut]
        else:
            temp = temp[cut:]
        images = np.concatenate((images, temp))
    return images
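
# Expected layout, inferred from the path handling above (shown for illustration):
#
#   <global_path>/<scene>/rgb/*.png     colour frames       (type_im="rgb")
#   <global_path>/<scene>/masks/*.png   segmentation masks  (type_im="masks")
#
# Within each scene the sorted file list is split 70/30 into train / non-train.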


class YCBVideo3D(Dataset):
    def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
                 max_len=None, full_scale=False, shapenet=False, downsample=None):
        """ Loads the YCB-Video dataset that we have adapted.

        Args:
            path (str): Path to dataset.
            mode (str): 'train', 'val', or 'test'.
            points_per_item (int): Number of target points per scene.
            max_len (int): Limit to the number of entries in the dataset.
            canonical_view (bool): Return data in canonical camera coordinates (like in SRT), as opposed
                to world coordinates.
            full_scale (bool): Return all available target points, instead of sampling.
            downsample (int): Downsample height and width of input image by a factor of 2**downsample
        """
        self.path = path
        self.mode = mode
        self.points_per_item = points_per_item
        self.max_len = max_len
        self.canonical = canonical_view
        self.full_scale = full_scale
        self.shapenet = shapenet
        self.downsample = downsample

        self.max_num_entities = 21 # max number of objects in a scene 
        self.num_views = 3 # TODO : set this number for each scene 

        self.start_idx, self.end_idx = {'train': (0, 70000),
                                        'val': (70000, 75000),
                                        'test': (85000, 100000)}[mode]

        self.metadata = np.load(os.path.join(path, 'metadata.npz'))
        self.metadata = {k: v for k, v in self.metadata.items()}

        self.idxs = np.arange(self.start_idx, self.end_idx)

        dataset_name = 'YCB-Video'
        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
        print(self.idxs)

        self.render_kwargs = {
            'min_dist': 0.035,
            'max_dist': 35.}

    def __len__(self):
        if self.max_len is not None:
            return self.max_len
        return len(self.idxs) * self.num_views

    def __getitem__(self, idx):
        scene_idx = idx % len(self.idxs)
        view_idx = idx // len(self.idxs)

        scene_idx = self.idxs[scene_idx]

        imgs = [np.asarray(imageio.imread(
            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
            for v in range(self.num_views)]

        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]

        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
                    for v in range(self.num_views)]
        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)

        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)

        all_rays = []
        # TODO : find a way to get the camera poses
        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
        all_camera_rot = self.metadata['camera_rot'][:self.num_views].astype(np.float32)
        for i in range(self.num_views):
            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i]) # TODO : adapt function
            all_rays.append(cur_rays)
        all_rays = np.stack(all_rays, 0).astype(np.float32)

        input_camera_pos = all_camera_pos[view_idx]

        if self.canonical:
            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
            canonical_extrinsic = canonical_extrinsic.astype(np.float32) 
            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
            input_camera_pos = all_camera_pos[view_idx]

        input_rays = all_rays[view_idx]
        input_rays = downsample(input_rays, num_steps=self.downsample)
        input_rays = np.expand_dims(input_rays, 0)

        input_masks = masks[view_idx]
        input_masks = downsample(input_masks, num_steps=self.downsample)
        input_masks = np.expand_dims(input_masks, 0)

        input_camera_pos = np.expand_dims(input_camera_pos, 0)

        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))

        num_points = all_rays.shape[0]

        if not self.full_scale:
            # If we have fewer points than we want, sample with replacement
            replace = num_points < self.points_per_item
            sampled_idxs = np.random.choice(np.arange(num_points),
                                            size=(self.points_per_item,),
                                            replace=replace)

            target_rays = all_rays[sampled_idxs]
            target_camera_pos = all_camera_pos[sampled_idxs]
            target_pixels = all_pixels[sampled_idxs]
            target_masks = all_masks[sampled_idxs]
        else:
            target_rays = all_rays
            target_camera_pos = all_camera_pos
            target_pixels = all_pixels
            target_masks = all_masks

        result = {
            'input_images':         input_images,         # [1, 3, h, w]
            'input_camera_pos':     input_camera_pos,     # [1, 3]
            'input_rays':           input_rays,           # [1, h, w, 3]
            'input_masks':          input_masks,          # [1, h, w, self.max_num_entities]
            'target_pixels':        target_pixels,        # [p, 3]
            'target_camera_pos':    target_camera_pos,    # [p, 3]
            'target_rays':          target_rays,          # [p, 3]
            'target_masks':         target_masks,         # [p, self.max_num_entities]
            'sceneid':              idx,                  # int
        }

        if self.canonical:
            result['transform'] = canonical_extrinsic     # [3, 4] (optional)

        return result
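
# Minimal usage sketch (the dataset root below is a placeholder; the loader expects
# `metadata.npz` plus `images/` and `masks/` folders underneath it):
#
#   from torch.utils.data import DataLoader
#   dataset = YCBVideo3D('/path/to/ycbv', mode='train', downsample=1)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   batch = next(iter(loader))
#   print(batch['input_images'].shape)   # [8, 1, 3, h, w]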

class YCBVideo2D(Dataset):
    def __init__(self, path, mode, max_objects=6):
        """ Loads the YCB dataset in the right format

        Args:
            path (str): Path to dataset.
            mode (str): 'train', 'val', or 'test'.
            full_scale (bool): Return all available target points, instead of sampling.
            max_objects (int): Load only scenes with at most this many objects.
        """
        self.path = path
        print(f"Get path {path}")
        self.mode = mode
        self.max_objects = max_objects

        self.max_num_entities = 22
        self.rescale = 128

        """self.metadata = np.load(os.path.join(self.path, 'metadata.npz'))
        self.metadata = {k: v for k, v in self.metadata.items()}

        num_objs = (self.metadata['shape'][self.start_idx:self.end_idx] > 0).sum(1)

        self.idxs = np.arange(self.start_idx, self.end_idx)[num_objs <= max_objects]"""

        self.images = np.empty(shape=(0,), dtype=np.str_)
        self.masks = np.empty(shape=(0,), dtype=np.str_)
        if mode == "test":
            self.path += "/test/"
            scenes = listdir(self.path)
            for scene in scenes:
                path = self.path + scene
                temp = np.array([f for f in listdir(path) if isfile(join(path, f))])
                self.images = np.concatenate((self.images, temp))
        else:
            path_real = self.path + "/train_real/"
            path_synth = self.path + "/train_synth/"
            self.images = extract_images_path(path_real, mode, self.images)
            self.images = extract_images_path(path_synth, mode, self.images)
            self.masks = extract_images_path(path_real, mode, self.masks, "masks")
            self.masks = extract_images_path(path_synth, mode, self.masks, "masks")
        dataset_name = 'YCB'

        print(f"Load dataset {dataset_name} in mode {self.mode}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx, noisy=True):
        """scene_idx = idx % len(self.idxs)
        scene_idx = self.idxs[scene_idx]"""

        img_path = self.images[idx]
        img = np.asarray(imageio.imread(img_path))
        img = img[..., :3].astype(np.float32) / 255

        input_image = crop_center(img, 440) 
        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=self.rescale)
        input_image = input_image.squeeze(0)

        mask_path = self.masks[idx]
        mask_idxs = imageio.imread(mask_path)

        masks = np.zeros((480, 640, self.max_num_entities), dtype=np.uint8)

        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)

        input_masks = crop_center(torch.tensor(masks), 440)
        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=self.rescale)
        input_masks = input_masks.squeeze(0).permute(1, 2, 0)
        target_masks = np.reshape(input_masks, (self.rescale*self.rescale, self.max_num_entities))

        result = {
            'input_images':         input_image,          # [3, h, w]
            'input_masks':          input_masks,          # [h, w, self.max_num_entities]
            'target_masks':         target_masks,         # [h*w, self.max_num_entities]
            'sceneid':              idx,                  # int
        }

        return result
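
# Minimal usage sketch (placeholder dataset root; in 'train' mode the file lists are
# built from the train_real/ and train_synth/ folders above):
#
#   dataset = YCBVideo2D('/path/to/ycbv', mode='train')
#   sample = dataset[0]
#   print(sample['input_images'].shape)   # [3, 128, 128]
#   print(sample['target_masks'].shape)   # [128*128, 22]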