import argparse

import yaml
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from torch.utils.data import DataLoader

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder


def main():
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help='Path to the training config file.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')

    args = parser.parse_args()
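
    # Example invocation (paths are placeholders; adjust to your setup):
    #   python <this script> path/to/config.yaml --wandb --seed 0
    # Note that --ckpt is parsed but currently unused in this Lightning setup.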
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    pl.seed_everything(args.seed, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    num_train_steps = cfg["training"]["max_it"]
    warmup_steps = cfg["training"]["warmup_it"]
    decay_rate = cfg["training"]["decay_rate"]
    decay_steps = cfg["training"]["decay_it"]
    resolution = (128, 128)
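
    # For reference, a minimal config sketch covering the keys read above; the
    # values are illustrative placeholders only, not recommended settings:
    #
    #   training:
    #     batch_size: 32
    #     num_gpus: 1
    #     num_workers: 4
    #     max_it: 300000
    #     warmup_it: 10000
    #     decay_rate: 0.5
    #     decay_it: 100000
    #   model:
    #     model_type: sa        # anything other than "sa" checkpoints to ./checkpoints_tsa
    #     num_slots: 7
    #     iters: 3
    #   data:                   # passed as-is to osrt.data.get_dataset
    #     ...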
    
    #### Create datasets
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
    
    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=1,
        shuffle=False, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)


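    # Keep the ten best checkpoints ranked by validation PSNR. This assumes the
    # LightningModule logs a "val_psnr" metric during validation; Lightning adds
    # the ".ckpt" extension to the filename on its own.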
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="val_psnr",
        mode="max",
        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}.pth",
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=num_gpus,
        profiler="simple",
        default_root_dir="./logs",
        logger=WandbLogger(project="slot-attention") if args.wandb else None,
        strategy="ddp" if num_gpus > 1 else "auto",
        callbacks=[checkpoint_callback],
        deterministic=True,
        log_every_n_steps=100,
        max_steps=num_train_steps)

    trainer.fit(model, train_loader, val_loader)
                
if __name__ == "__main__":
    main()