import os
import tempfile

import torch
import torch.optim as optim
from ray.air import RunConfig, CheckpointConfig
from ray.tune.search.ax import AxSearch
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.optuna import OptunaSearch
from ray.util.client import ray

import common_dataset
import dataloader
from config_common import load_args
from loss import masked_cos_sim
from model_custom import Model_Common_Transformer
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import HyperBandForBOHB, ASHAScheduler


def _unwrap(net):
    """Return the inner module if *net* is a ``torch.nn.DataParallel``, else *net* itself."""
    return net.module if isinstance(net, torch.nn.DataParallel) else net


def _batch_loss(net, data, forward, criterion_rt, criterion_intensity):
    """Move one batch to the GPU (when available), run the model and return the loss tensor.

    Args:
        net: the model, possibly wrapped in ``torch.nn.DataParallel``.
        data: one batch from a loader. It is unpacked as ``(seq, rt)`` when
            ``forward == 'rt'``, ``(seq, charge, intensity)`` when ``forward == 'int'``
            and ``(seq, charge, rt, intensity)`` otherwise.
        forward: 'rt' / 'int' select a single head; any other value means the joint model.
        criterion_rt: loss for the ``rt`` head, called as ``criterion_rt(target, pred)``.
        criterion_intensity: loss for the intensity head, called as
            ``criterion_intensity(target, pred)``.

    Returns:
        The scalar loss tensor (sum of both losses in joint mode).
    """
    use_cuda = torch.cuda.is_available()

    if forward == 'rt':
        seq, rt = data
        rt = rt.float()
        if use_cuda:
            seq, rt = seq.cuda(), rt.cuda()
        # Single-head forwards are not ``forward()``, so they are called on the inner module.
        pred_rt = _unwrap(net).forward_rt(seq)
        return criterion_rt(rt, pred_rt)

    if forward == 'int':
        seq, charge, intensity = data
        intensity = intensity.float()
        if use_cuda:
            seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
        pred_int = _unwrap(net).forward_int(seq, charge)
        return criterion_intensity(intensity, pred_int)

    seq, charge, rt, intensity = data
    rt, intensity = rt.float(), intensity.float()
    if use_cuda:
        seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
    pred_rt, pred_int = net(seq, charge)
    return criterion_rt(rt, pred_rt) + criterion_intensity(intensity, pred_int)


def train_model(config, args):
    """Ray Tune trainable: build a model from *config*, train it and report validation loss.

    After every epoch the model is evaluated on the validation split, and
    ``(model_state, optimizer_state)`` is saved to ``checkpoint.pt``. Then
    ``{"loss": mean validation batch loss}`` is reported to Ray with that checkpoint.
    If Ray hands back a checkpoint (e.g. on trial restore), training resumes from it.

    Args:
        config: hyperparameters sampled by Ray Tune (the keys of ``main``'s search space).
        args: parsed CLI arguments. ``args.forward`` selects 'rt', 'int' or the joint model,
            and ``args.dataset_train/val/test`` locate the data.

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    net = Model_Common_Transformer(encoder_ff=int(config["encoder_ff"]),
                                   decoder_rt_ff=int(config["decoder_rt_ff"]),
                                   decoder_int_ff=int(config["decoder_int_ff"]),
                                   n_head=int(config["n_head"]),
                                   encoder_num_layer=int(config["encoder_num_layer"]),
                                   decoder_int_num_layer=int(config["decoder_int_num_layer"]),
                                   decoder_rt_num_layer=int(config["decoder_rt_num_layer"]),
                                   drop_rate=float(config["drop_rate"]),
                                   embedding_dim=int(config["embedding_dim"]),
                                   acti=config["activation"],
                                   norm=config["norm_first"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            print(type(net))
            net = torch.nn.DataParallel(net)
            print(type(net))
    net.to(device)

    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    # Resume from the checkpoint Ray provides, if any.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,  # checkpoint may have been written from another device
            )
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    if args.forward == 'both':
        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                   path_val=args.dataset_val,
                                                                   path_test=args.dataset_test,
                                                                   batch_size=int(config["batch_size"]), length=25)
    else:
        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
                                                               batch_size=int(config["batch_size"]), length=25)

    num_epochs = 100  # ASHAScheduler in main() is configured with max_t=100
    for epoch in range(num_epochs):
        # Training pass (dropout active).
        net.train()
        running_loss = 0.0
        window_steps = 0  # batches accumulated in running_loss since the last print
        for i, data in enumerate(data_train):
            loss = _batch_loss(net, data, args.forward, criterion_rt, criterion_intensity)
            running_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            window_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / window_steps))
                running_loss = 0.0
                window_steps = 0

        # Validation pass: eval mode disables dropout so the reported metric is deterministic.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        with torch.no_grad():
            for data in data_val:
                # .item() already returns a Python float.
                val_loss += _batch_loss(net, data, args.forward, criterion_rt, criterion_intensity).item()
                val_steps += 1
        if val_steps == 0:
            raise ValueError("validation loader yielded no batches; cannot compute 'loss'")

        # Save a checkpoint and report it to Ray Tune. It may be handed back later through
        # ``train.get_checkpoint()``. The file must live in a directory to build a Checkpoint,
        # and ``train.report`` is called while that temporary directory still exists.
        with tempfile.TemporaryDirectory(
                dir='/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/checkpoints') as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save((net.state_dict(), optimizer.state_dict()), path)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report(
                {"loss": val_loss / val_steps},
                checkpoint=checkpoint,
            )
    print("Finished Training")


def test_best_model(best_result, args):
    """Rebuild the best trial's model from its checkpoint and evaluate it on the test split.

    Args:
        best_result: Ray Tune result of the best trial. ``best_result.config`` holds the
            sampled hyperparameters. ``best_result.checkpoint`` holds ``checkpoint.pt``, a
            ``(model_state, optimizer_state)`` tuple as saved by ``train_model``.
        args: parsed CLI arguments. ``args.forward`` selects 'rt', 'int' or the joint model
            (any other value is treated as joint, as in ``train_model``), and
            ``args.dataset_*`` locate the data.

    Returns:
        float: the mean per-batch test loss (also printed).

    Raises:
        ValueError: if the test loader yields no batches.
    """
    config = best_result.config
    best_trained_model = Model_Common_Transformer(encoder_ff=int(config["encoder_ff"]),
                                                  decoder_rt_ff=int(config["decoder_rt_ff"]),
                                                  decoder_int_ff=int(config["decoder_int_ff"]),
                                                  n_head=int(config["n_head"]),
                                                  # Previously read config["batch_size"] by mistake.
                                                  encoder_num_layer=int(config["encoder_num_layer"]),
                                                  decoder_int_num_layer=int(config["decoder_int_num_layer"]),
                                                  decoder_rt_num_layer=int(config["decoder_rt_num_layer"]),
                                                  drop_rate=float(config["drop_rate"]),
                                                  embedding_dim=int(config["embedding_dim"]),
                                                  acti=config["activation"],
                                                  norm=config["norm_first"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            best_trained_model = torch.nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    # Inner module: used for the single-head forwards and for loading weights.
    if isinstance(best_trained_model, torch.nn.DataParallel):
        inner_model = best_trained_model.module
    else:
        inner_model = best_trained_model

    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim

    # as_directory() avoids leaving behind the temp copy that to_directory() makes.
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        model_state, _optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"),
                                                   map_location=device)
    # The trial may have saved a DataParallel state dict ("module."-prefixed keys) while this
    # process uses a different GPU count. Strip the prefix and load into the inner module.
    prefix = "module."
    model_state = {(key[len(prefix):] if key.startswith(prefix) else key): value
                   for key, value in model_state.items()}
    inner_model.load_state_dict(model_state)
    best_trained_model.eval()  # disable dropout for evaluation

    if args.forward == 'both':
        _, _, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                   path_val=args.dataset_val,
                                                   path_test=args.dataset_test,
                                                   batch_size=int(config["batch_size"]),
                                                   length=25)
    else:
        _, _, data_test = dataloader.load_data(data_source=args.dataset_train,
                                               batch_size=int(config["batch_size"]), length=25)

    use_cuda = torch.cuda.is_available()
    test_loss = 0.0
    test_steps = 0
    with torch.no_grad():
        for data in data_test:
            if args.forward == 'rt':
                seq, rt = data
                rt = rt.float()
                if use_cuda:
                    seq, rt = seq.cuda(), rt.cuda()
                pred_rt = inner_model.forward_rt(seq)
                loss = criterion_rt(rt, pred_rt)
            elif args.forward == 'int':
                seq, charge, intensity = data
                intensity = intensity.float()
                if use_cuda:
                    seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
                pred_int = inner_model.forward_int(seq, charge)
                loss = criterion_intensity(intensity, pred_int)
            else:
                seq, charge, rt, intensity = data
                rt, intensity = rt.float(), intensity.float()
                if use_cuda:
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pred_rt, pred_int = best_trained_model(seq, charge)
                loss = criterion_rt(rt, pred_rt) + criterion_intensity(intensity, pred_int)
            test_loss += loss.item()  # .item() is already a Python float
            test_steps += 1

    if test_steps == 0:
        raise ValueError("test loader yielded no batches")
    mean_test_loss = test_loss / test_steps
    print("Best trial test set loss: {}".format(mean_test_loss))
    return mean_test_loss


def main(args, gpus_per_trial=1, cpus_per_trial=80):
    """Run the Ray Tune hyperparameter search, then evaluate the best trial on the test set.

    Args:
        args: parsed CLI arguments, forwarded to ``train_model`` and ``test_best_model``.
        gpus_per_trial: GPUs reserved for each trial.
        cpus_per_trial: CPUs reserved for each trial (defaults to the previous hard-coded 80).
    """
    # Smaller search space kept for quick smoke tests:
    # config = {
    #     "encoder_num_layer": tune.choice([1]),
    #     "decoder_rt_num_layer": tune.choice([1]),
    #     "decoder_int_num_layer": tune.choice([1]),
    #     "embedding_dim": tune.choice([16, 64, 256, 1024]),
    #     "encoder_ff": tune.choice([512]),
    #     "decoder_rt_ff": tune.choice([512]),
    #     "decoder_int_ff": tune.choice([512]),
    #     "n_head": tune.choice([1]),
    #     "drop_rate": tune.choice([0.2]),
    #     "lr": tune.choice([1e-4]),
    #     "batch_size": tune.choice([1024]),
    # }
    config = {
        "encoder_num_layer": tune.choice([2, 4, 8]),
        "decoder_rt_num_layer": tune.choice([2, 4, 8]),
        "decoder_int_num_layer": tune.choice([1]),
        "embedding_dim": tune.choice([16, 64]),
        "encoder_ff": tune.choice([512, 1024, 2048]),
        "decoder_rt_ff": tune.choice([512, 1024, 2048]),
        "decoder_int_ff": tune.choice([512]),
        "n_head": tune.choice([1, 2, 4, 8, 16]),
        "drop_rate": tune.choice([0.25]),
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([4096]),
        "activation": tune.choice(['relu', 'gelu']),
        "norm_first": tune.choice([True, False]),
    }
    # Early-stopping scheduler: each "time step" is one train.report() call (one epoch).
    scheduler = ASHAScheduler(
        max_t=100,
        grace_period=30,
        reduction_factor=3,
        brackets=1,
    )
    algo = OptunaSearch()

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=20,
            metric='loss',
            mode='min',
        ),
        # NOTE(review): the experiment name says "no_scheduler" but ASHA is in use; kept
        # unchanged because it determines the results directory.
        run_config=RunConfig(storage_path="/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/ray_results_test",
                             name="test_experiment_no_scheduler"),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    # train_model reports only "loss". The former "accuracy" lookup raised KeyError here and
    # prevented the test-set evaluation from ever running.

    test_best_model(best_result, args)


if __name__ == "__main__":
    # Show which CUDA devices this process can see before tuning starts.
    visible_gpu_names = [torch.cuda.get_device_properties(gpu_idx).name
                         for gpu_idx in range(torch.cuda.device_count())]
    for gpu_name in visible_gpu_names:
        print(gpu_name)
    # Fixed seed for torch's RNG in this process.
    torch.manual_seed(2809)
    main(load_args(), gpus_per_trial=4)