import torch
import torch.optim as optim
import numpy as np
import imageio

import os, sys, argparse, math
import yaml, json
from tqdm import tqdm

from osrt.data import get_dataset
from osrt.checkpoint import Checkpoint
from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
from osrt.model import OSRT
from osrt.trainer import SRTTrainer

from compile_video import compile_video_render, compile_video_plot


def get_camera_rays_render(camera_pos, **kwargs):
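    """Compute camera rays for a single (unbatched) camera position and re-add the batch dimension."""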
    rays = get_camera_rays(camera_pos[0], **kwargs)
    return np.expand_dims(rays, 0)

def lerp(x, y, t):
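    """Linearly interpolate between x and y with parameter t in [0, 1]."""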
    return x + (y-x) * t

def easeout(t):
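    """Quadratic easing that starts fast and decelerates, mapping [0, 1] onto [0, 1]."""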
    return -0.5 * t**2 + 1.5 * t

def apply_fade(t, t_fade=0.2):
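    """Warp the timeline t in [0, 1] so the motion accelerates from rest over the first t_fade,
    cruises at constant speed, and decelerates back to rest over the final t_fade, while still
    covering the full [0, 1] range by the end of the sequence.
    """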
    v_max = 1. / (1. - t_fade)
    acc = v_max / t_fade
    if t <= t_fade:
        return 0.5 * acc * t**2
    pos_past_fade = 0.5 * acc * t_fade**2
    if t <= 1. - t_fade:
        return pos_past_fade + v_max * (t - t_fade)
    else:
        return 1. - 0.5 * acc * (t - 1.)**2

def get_camera_closeup(camera_pos, rays, t, zoomout=1., closeup=0.2, z_closeup=0.1, lookup=3.):
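    """Interpolate the camera from its original view (t=0) towards a low, close-up view (t=1)
    that looks past the scene center, and rotate the rays to match the new orientation.
    Returns the batched camera position and the transformed rays.
    """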
    orig_camera_pos = camera_pos[0] * zoomout
    orig_track_point = torch.zeros_like(orig_camera_pos)
    orig_ext = get_extrinsic_torch(orig_camera_pos, track_point=orig_track_point, fourxfour=True)

    final_camera_pos = closeup * orig_camera_pos
    final_camera_pos[2] = z_closeup * orig_camera_pos[2]
    final_track_point = orig_camera_pos + (orig_track_point - orig_camera_pos) * lookup
    final_track_point[2] = 0.

    cur_camera_pos = lerp(orig_camera_pos, final_camera_pos, t)
    cur_camera_pos[2] = lerp(orig_camera_pos[2], final_camera_pos[2], easeout(t))
    cur_track_point = lerp(orig_track_point, final_track_point, t)

    new_ext = get_extrinsic_torch(cur_camera_pos, track_point=cur_track_point, fourxfour=True)

    cur_rays = transform_points_torch(rays, torch.inverse(new_ext) @ orig_ext, translate=False)
    return cur_camera_pos.unsqueeze(0), cur_rays


def rotate_camera(camera_pos, rays, t):
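    """Rotate the camera position and rays by a full turn around the scene's z-axis as t goes from 0 to 1."""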
    theta = math.pi * 2 * t
    camera_pos = rotate_around_z_axis_torch(camera_pos, theta)
    rays = rotate_around_z_axis_torch(rays, theta)

    return camera_pos, rays


def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolution=None, **render_kwargs):
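    """Render args.num_frames frames of the scene encoded in z under the chosen camera motion.

    RGB frames (plus depth and segmentation maps when the renderer provides them) are written as
    numbered PNGs into the 'renders', 'depths', and 'segmentations' subdirectories of render_path.
    If transform is given, the camera is mapped into world space before the motion is applied and
    mapped back into model coordinates before rendering. Uses the global args for frame count and fade.
    """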
    if transform is not None:  # Project camera into world space before applying motion transformations
        inv_transform = torch.inverse(transform)
        camera_pos = transform_points_torch(camera_pos, inv_transform)

    camera_pos_np = camera_pos.cpu().numpy()
    rays = torch.Tensor(get_camera_rays_render(camera_pos_np, **resolution)).to(camera_pos)

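    # t sweeps from 0 to 1 over the sequence; each motion type derives the current camera and rays from it.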
    for frame in tqdm(range(args.num_frames)):
        t = frame / args.num_frames
        if args.fade:
            t = apply_fade(t)
        if motion == 'rotate':  # Rotate camera around scene, tracking scene's center
            cur_camera_pos, cur_rays = rotate_camera(camera_pos, rays, t)
        elif motion == 'zoom':  # Stationary camera and track point, zoom in by reducing sensor width
            sensor_max = 0.032
            sensor_min = sensor_max / 5
            sensor_cur = lerp(sensor_max, sensor_min, t)
            cur_rays = get_camera_rays_render(camera_pos_np, sensor_width=sensor_cur, **resolution)
            cur_rays = torch.Tensor(cur_rays).to(camera_pos)  # keep rays on the camera's device (respects --no-cuda)
            cur_camera_pos = camera_pos
        elif motion == 'closeup':  # Move camera towards center of the scene, pan up slightly
            cur_camera_pos, cur_rays = get_camera_closeup(camera_pos, rays, t)
        elif motion == 'rotate_and_closeup':  # Rotate while moving in for a slight closeup
            t_closeup = ((-math.cos(t * math.pi * 2) + 1) * 0.5) * 0.5
            cur_camera_pos, cur_rays = get_camera_closeup(camera_pos, rays, t_closeup, lookup=1.5)
            cur_camera_pos, cur_rays = rotate_camera(cur_camera_pos, cur_rays, t)
        elif motion == 'eyeroll':  # Stationary camera, tracking circle around the scene
            theta = -t * 2 * math.pi
            track_point = 1.5 * np.array((math.cos(theta), math.sin(theta), 0))
            cur_rays = get_camera_rays_render(camera_pos_np, track_point=track_point, **resolution)
            cur_rays = torch.Tensor(cur_rays).to(camera_pos)  # keep rays on the camera's device (respects --no-cuda)
            cur_camera_pos = camera_pos
        else:
            raise ValueError(f'Unknown motion: {motion}')

        if transform is not None:  # Project camera back into canonical model coordinates
            cur_camera_pos = transform_points_torch(cur_camera_pos, transform)
            cur_rays = transform_points_torch(cur_rays, transform, translate=False)

        render, extras = trainer.render_image(z, cur_camera_pos, cur_rays, **render_kwargs)
        render = render.squeeze(0)
        render = render.cpu().numpy()
        render = (render * 255.).astype(np.uint8)
        imageio.imwrite(os.path.join(render_path, 'renders', f'{frame}.png'), render)

        if 'depth' in extras:
            depths = extras['depth'].squeeze(0).cpu().numpy()
            depths = (np.clip(depths / render_kwargs['max_dist'], 0., 1.) * 255.).astype(np.uint8)
            imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths)

        if 'segmentation' in extras:
            pred_seg = extras['segmentation'].squeeze(0).cpu()
            colors = get_clustering_colors(pred_seg.shape[-1] + 1)
            pred_seg = pred_seg.argmax(-1).numpy() + 1
            pred_img = visualize_2d_cluster(pred_seg, colors)
            pred_img = (pred_img * 255.).astype(np.uint8)
            imageio.imwrite(os.path.join(render_path, 'segmentations', f'{frame}.png'), pred_img)


def process_scene(sceneid):
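    """Encode the input views of scene sceneid, render the requested camera motion from the
    resulting latent representation, and optionally compile the frames into a video.
    """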
    render_path = os.path.join(out_dir, 'render', args.name, str(sceneid))
    if os.path.exists(render_path):
        print(f'Warning: Path {render_path} exists. Contents will be overwritten.')

    os.makedirs(render_path, exist_ok=True)
    subdirs = ['renders', 'depths', 'segmentations']
    for d in subdirs:
        os.makedirs(os.path.join(render_path, d), exist_ok=True)

    if isinstance(val_dataset, torch.utils.data.IterableDataset):
        data = next(val_iterator)
    else:
        data = val_dataset[sceneid]

    input_images = torch.Tensor(data['input_images']).to(device).unsqueeze(0)
    input_camera_pos = torch.Tensor(data['input_camera_pos']).to(device).unsqueeze(0)
    input_rays = torch.Tensor(data['input_rays']).to(device).unsqueeze(0)

    resolution = {'height': input_rays.shape[2],
                  'width': input_rays.shape[3]}
    if args.height is not None:
        resolution['height'] = args.height
    if args.width is not None:
        resolution['width'] = args.width

    if 'transform' in data:
        transform = torch.Tensor(data['transform']).to(device)
    else:
        transform = None

    for i in range(input_images.shape[1]):
        input_np = (np.transpose(data['input_images'][i], (1, 2, 0)) * 255.).astype(np.uint8)
        imageio.imwrite(os.path.join(render_path, f'input_{i}.png'), input_np)

    with torch.no_grad():
        z = model.encoder(input_images, input_camera_pos, input_rays)
        print('Rendering frames...')
        render3d(trainer, render_path, z, input_camera_pos[:, 0],
                 motion=args.motion, transform=transform, resolution=resolution, **render_kwargs)

    if not args.novideo:
        print('Compiling plot video...')
        compile_video_plot(render_path, frames=True, num_frames=args.num_frames)

if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser(
        description='Render a video of a scene.'
    )
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
    parser.add_argument('--num-frames', type=int, default=360, help='Number of frames to render.')
    parser.add_argument('--sceneid', type=int, default=0, help='Id of the scene to render.')
    parser.add_argument('--sceneid-start', type=int, help='First scene id to render (inclusive).')
    parser.add_argument('--sceneid-stop', type=int, help='Last scene id to render (exclusive).')
    parser.add_argument('--height', type=int, help='Rendered image height in pixels. Defaults to input image height.')
    parser.add_argument('--width', type=int, help='Rendered image width in pixels. Defaults to input image width.')
    parser.add_argument('--name', type=str, help='Name of this sequence.')
    parser.add_argument('--motion', type=str, default='rotate', help='Type of sequence.')
    parser.add_argument('--sharpen', action='store_true', help='Square density values for sharper surfaces.')
    parser.add_argument('--parallel', action='store_true', help='Wrap model in DataParallel.')
    parser.add_argument('--train', action='store_true', help='Use training data.')
    parser.add_argument('--fade', action='store_true', help='Add fade in/out.')
    parser.add_argument('--it', type=int, help='Iteration of the model to load.')
    parser.add_argument('--render-kwargs', type=str, help='Renderer kwargs as JSON dict')
    parser.add_argument('--novideo', action='store_true', help="Don't compile rendered images into video")

    args = parser.parse_args()
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)
    print('Config loaded.')
    is_cuda = (torch.cuda.is_available() and not args.no_cuda)
    device = torch.device("cuda" if is_cuda else "cpu")

    out_dir = os.path.dirname(args.config)
    exp_name = os.path.basename(out_dir)
    if args.name is None:
        args.name = args.motion
    if args.render_kwargs is not None:
        render_kwargs = json.loads(args.render_kwargs)
    else:
        render_kwargs = dict()

    model = OSRT(cfg['model']).to(device)
    model.eval()

    mode = 'train' if args.train else 'val'
    val_dataset = get_dataset(mode, cfg['data'])

    # Dataset defaults first, so explicit --render-kwargs take precedence.
    render_kwargs = val_dataset.render_kwargs | render_kwargs

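    # The optimizer only exists to satisfy the Checkpoint / SRTTrainer interfaces; it is never stepped here.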
    optimizer = optim.Adam(model.parameters())
    trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, val_dataset.render_kwargs)

    checkpoint = Checkpoint(out_dir, encoder=model.encoder, decoder=model.decoder, optimizer=optimizer)
    if args.it is not None:
        load_dict = checkpoint.load(f'model_{args.it}.pt')
    else:
        load_dict = checkpoint.load('model.pt')
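
    if args.parallel:
        # The --parallel flag is otherwise unused in this script; this is a sketch of one way to honor it.
        # Wrapping after the checkpoint is loaded keeps the saved state-dict keys unprefixed; how much
        # DataParallel actually helps depends on how the trainer batches rays.
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.decoder = torch.nn.DataParallel(model.decoder)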

    if args.sceneid_start is None:
        args.sceneid_start = args.sceneid
    if args.sceneid_stop is None:
        args.sceneid_stop = args.sceneid_start + 1

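    # Iterable datasets cannot be indexed, so skip ahead to the first requested scene and read sequentially.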
    if isinstance(val_dataset, torch.utils.data.IterableDataset):
        val_dataset.skip(args.sceneid_start)
        val_iterator = iter(val_dataset)

    for i in range(args.sceneid_start, args.sceneid_stop):
        process_scene(i)