import torch
import torch.nn as nn
import argparse
import yaml

from osrt.model import LitSlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualization_utils import visualize_slot_attention
from osrt.utils.common import mse2psnr

from torch.utils.data import DataLoader
import torch.nn.functional as F
import os

import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint

import warnings
from lightning.pytorch.utilities.warnings import PossibleUserWarning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Suppress Lightning's PossibleUserWarning category, which flags potential
# (often false-positive) performance issues — see
# https://lightning.ai/docs/pytorch/stable/advanced/speed.html
warnings.filterwarnings("ignore", category=PossibleUserWarning)

def main():
    """Train a slot-attention 3D scene representation model from a YAML config.

    Command line:
        config        path to the YAML configuration file
        --wandb       log the run to Weights and Biases
        --seed        random seed (default 0)
        --ckpt        optional checkpoint to initialize the model weights from

    Side effects: reads the config file, trains the model (checkpoints under
    ./checkpoints, logs under ./logs), then visualizes one sample's slot
    decomposition and prints its PSNR.
    """
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')

    args = parser.parse_args()
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    # Fix: honor the --seed CLI argument (was hard-coded to 42, making the
    # flag a no-op).
    pl.seed_everything(args.seed, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    num_train_steps = cfg["training"]["max_it"]
    resolution = (128, 128)

    print(f"Number of CPU Cores : {os.cpu_count()}")

    #### Create datasets
    # Fix: removed `val_every = val_every // len(train_dataset)`, which read
    # an undefined variable (UnboundLocalError) and whose result was unused —
    # the Trainer below takes the validation interval straight from the config.
    train_dataset = data.get_dataset('train', cfg['data'])
    # Clamp so a small configured value cannot yield a negative worker count,
    # which DataLoader rejects. (The -9 offset reserves cores for other
    # processes — presumably tuned for a specific machine; TODO confirm.)
    train_workers = max(0, cfg["training"]["num_workers"] - 9)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, num_workers=train_workers,
        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)

    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=8,
        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)

    #### Create model
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)

    if args.ckpt:
        # map_location='cpu' so GPU-saved checkpoints load even before the
        # model is moved to a device; Lightning re-places the weights later.
        checkpoint = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    checkpoint_callback = ModelCheckpoint(
        monitor="val_psnr",
        mode="max",
        dirpath="./checkpoints",
        filename="ckpt-" +  str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
        save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
        every_n_train_steps=cfg["training"]["checkpoint_every"]
    )

    early_stopping = EarlyStopping(monitor="val_psnr", mode="max")

    trainer = pl.Trainer(accelerator="gpu",
                         devices=num_gpus,
                         profiler="simple",
                         default_root_dir="./logs",
                         logger=WandbLogger(project="slot-att") if args.wandb else None,
                         strategy="ddp" if num_gpus > 1 else "auto",
                         callbacks=[checkpoint_callback, early_stopping],
                         log_every_n_steps=100,
                         val_check_interval=cfg["training"]["validate_every"],
                         max_steps=num_train_steps,
                         enable_model_summary=True)

    trainer.fit(model, train_loader, val_loader)

    #### Create dataset for a single-sample visualization pass
    vis_dataset = data.get_dataset('train', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=1, num_workers=1,
        shuffle=True, worker_init_fn=data.worker_init_fn)

    device = model.device

    #### Visualize image
    # Run inference in eval mode without grad tracking: this is pure
    # visualization, so autograd buffers are wasted memory and train-mode
    # layers (dropout/BN, if any) would perturb the output.
    model.eval()
    with torch.no_grad():
        image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
        image = F.interpolate(image, size=128)
        image = image.to(device)
        recon_combined, recons, masks, slots, _ = model(image)
        loss = nn.MSELoss()
        loss_value = loss(recon_combined, image)
        psnr = mse2psnr(loss_value)
        # Fix: PSNR was computed and silently discarded; report it.
        print(f"Visualization sample PSNR: {psnr}")
        visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)

if __name__ == "__main__":
    main()