import os

import torch
import torch.optim as optim
import wandb as wdb
import numpy as np

from config import load_args
from dataloader import load_data, load_intensity_from_files
from loss import masked_cos_sim, distance, masked_spectral_angle
from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head, RT_pred_model_self_attention_pretext,
                   RT_pred_model_self_attention_multi, RT_pred_model_self_attention_multi_sum,
                   RT_pred_model_transformer)


# from torcheval.metrics import R2Score


# def compute_metrics(model, data_val, f_name):
#     name = os.path.join('checkpoints', f_name)
#     model.load_state_dict(torch.load(name))
#     model.eval()
#     targets = []
#     preds = []
#     r2 = R2Score()
#     for data, target in data_val:
#         targets.append(target)
#         pred = model(data)
#         preds.append(pred)
#     full_target = torch.concat(targets, dim=0)
#     full_pred = torch.concat(preds, dim=0)
#
#     r2.update(full_pred, full_target)
#     diff = torch.abs(full_target - full_pred)
#     sorted_diff, _ = diff.sort()
#     delta_95 = sorted_diff[int(np.floor(sorted_diff.size(dim=0) * 0.95))].item()
#     score = r2.compute()
#     return score, delta_95


def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
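    """Run one training epoch of the retention-time model.

    Computes the criterion loss and the distance metric on each batch of
    `data_train`, backpropagates, and logs the epoch averages to wandb
    (when a run name is given) and to stdout.
    """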
    losses = 0.
    dist_acc = 0.
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    for data, target in data_train:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt = model(data)
        target = target.float()
        loss = criterion(pred_rt, target)
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train mean metric": dist_acc / len(data_train),
                 'train epoch': epoch})

    print(f'epoch: {epoch}, train loss: {losses / len(data_train)}, '
          f'mean metric: {dist_acc / len(data_train)}')


def train_pretext(model, data_train, epoch, optimizer, criterion, task, metric, coef, wandb=None):
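    """Run one training epoch of the RT model with its sequence-reconstruction pretext task.

    The optimised loss is criterion(pred_rt, target) + coef * task(pred_seq, data);
    the RT loss, the pretext loss and the distance metric are averaged over the
    epoch and logged.
    """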
    losses, losses_2 = 0., 0.
    dist_acc = 0.
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    for data, target in data_train:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt, pred_seq = model(data)
        pred_seq = pred_seq.transpose(1, 2)
        target = target.float()
        loss = criterion(pred_rt, target)
        loss_2 = task(pred_seq, data)
        losses_2 += loss_2.item()
        loss_tot = loss + coef * loss_2
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train loss pretext": losses_2 / len(data_train),
                 "train mean metric": dist_acc / len(data_train), 'train epoch': epoch})

    print(f'epoch: {epoch}, train loss: {losses / len(data_train)}, '
          f'train pretext loss: {losses_2 / len(data_train)}, '
          f'mean metric: {dist_acc / len(data_train)}')


def train_int(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
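    """Run one training epoch of the fragment-intensity model.

    Each batch yields three model inputs and a target intensity vector; the
    criterion loss and the metric are averaged over the epoch and logged.
    """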
    losses = 0.
    dist_acc = 0.
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    for data1, data2, data3, target in data_train:
        if torch.cuda.is_available():
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        pred_rt = model(data1, data2, data3)
        target = target.float()
        loss = criterion(pred_rt, target)
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train mean metric": dist_acc / len(data_train),
                 'train epoch': epoch})

    print(f'epoch: {epoch}, train loss: {losses / len(data_train)}, '
          f'mean metric: {dist_acc / len(data_train)}')


def eval_int(model, data_val, epoch, criterion, metric, wandb=None):
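    """Evaluate the intensity model on `data_val` and return the mean validation loss."""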
    losses = 0.
    dist_acc = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data1, data2, data3, target in data_val:
        if torch.cuda.is_available():
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        pred_rt = model(data1, data2, data3)
        loss = criterion(pred_rt, target)
        losses += loss.item()
        dist = metric(pred_rt, target)
        dist_acc += dist.item()

    if wandb is not None:
        wdb.log({"eval loss": losses / len(data_val), 'eval epoch': epoch, "eval metric": dist_acc / len(data_val)})
    print(f'epoch: {epoch}, eval loss: {losses / len(data_val)}, '
          f'eval mean metric: {dist_acc / len(data_val)}')
    return losses / len(data_val)


def eval_rt(model, data_val, epoch, criterion, metric, wandb=None):
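    """Evaluate the RT model on `data_val` and return the mean value of the metric."""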
    losses = 0.
    dist_acc = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data, target in data_val:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt = model(data)
        loss = criterion(pred_rt, target)
        losses += loss.item()
        dist = metric(pred_rt, target)
        dist_acc += dist.item()

    if wandb is not None:
        wdb.log({"eval loss": losses / len(data_val), 'eval epoch': epoch, "eval metric": dist_acc / len(data_val)})
    print(f'epoch: {epoch}, eval loss: {losses / len(data_val)}, '
          f'eval mean metric: {dist_acc / len(data_val)}')

    return dist_acc / len(data_val)


def eval_pretext(model, data_val, epoch, criterion, metric, wandb=None):
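    """Evaluate the RT head of the pretext model on `data_val` and return the mean metric."""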
    losses = 0.
    dist_acc = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data, target in data_val:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt, _ = model(data)
        loss = criterion(pred_rt, target)
        losses += loss.item()
        dist = metric(pred_rt, target)
        dist_acc += dist.item()

    if wandb is not None:
        wdb.log({"eval loss": losses / len(data_val), 'eval epoch': epoch, "eval metric": dist_acc / len(data_val)})
    print(f'epoch: {epoch}, eval loss: {losses / len(data_val)}, '
          f'eval mean metric: {dist_acc / len(data_val)}')

    return dist_acc / len(data_val)


def save(model, optimizer, epoch, checkpoint_name):
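    """Serialise the full model object to checkpoints/<checkpoint_name>.

    Note: torch.save is applied to the whole model (not a state_dict); the
    optimizer and epoch arguments are currently unused.
    """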
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model, os.path.join('checkpoints', checkpoint_name))


def load(path):
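    """Load a model object previously written by save() from the checkpoints directory."""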
    model = torch.load(os.path.join('checkpoints', path))
    return model


def run_rt(epochs, eval_inter, save_inter, model, data_train, data_val, optimizer, criterion, metric, wandb=None):
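    """Train the RT model for `epochs` epochs, evaluating every `eval_inter`
    epochs and saving a checkpoint every `save_inter` epochs.
    """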
    for e in range(1, epochs + 1):
        train_rt(model, data_train, e, optimizer, criterion, metric, wandb=wandb)
        if e % eval_inter == 0:
            eval_rt(model, data_val, e, criterion, metric, wandb=wandb)
        if e % save_inter == 0:
            save(model, optimizer, epochs, 'model_self_attention_' + str(e) + '.pt')


def run_pretext(epochs, eval_inter, model, data_train, data_val, data_test, optimizer, criterion, task, metric, coef,
                wandb=None):
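    """Train the pretext RT model, checkpoint the best model on the validation
    metric, then reload that checkpoint and evaluate it on the test set.
    """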
    best_dist = 10000
    best_epoch = 0
    for e in range(1, epochs + 1):
        train_pretext(model, data_train, e, optimizer, criterion, task, metric, coef, wandb=wandb)
        if e % eval_inter == 0:
            dist = eval_pretext(model, data_val, e, criterion, metric, wandb=wandb)
            if dist < best_dist:
                best_dist = dist
                best_epoch = e
                if wandb is not None:
                    save(model, optimizer, epochs, 'model_self_attention_pretext_' + wandb + '.pt')
                else:
                    save(model, optimizer, epochs, 'model_self_attention_pretext.pt')

    if wandb is not None:
        model_final = load('model_self_attention_pretext_' + wandb + '.pt')
    else:
        model_final = load('model_self_attention_pretext.pt')
    eval_pretext(model_final, data_test, 0, criterion, metric, wandb=wandb)
    print(f'Best epoch: {best_epoch}')


def run_int(epochs, eval_inter, save_inter, model, data_train, data_val, optimizer, criterion, metric,
            wandb=None):
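    """Train the intensity model for `epochs` epochs, evaluating every `eval_inter`
    epochs (best-model checkpointing is currently commented out).
    """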
    best_loss = 10000
    best_epoch = 0
    for e in range(1, epochs + 1):
        train_int(model, data_train, e, optimizer, criterion, metric, wandb=wandb)
        if e % eval_inter == 0:
            loss = eval_int(model, data_val, e, criterion, metric, wandb=wandb)
        #     if loss < best_loss:
        #         best_epoch = e
        #         if wandb is not None:
        #             save(model, optimizer, epochs, 'model_int' + wandb + '.pt')
        #         else:
        #             save(model, optimizer, epochs, 'model_int.pt')
        # if wandb is not None:
        #     model_final = load('model_int' + wandb + '.pt')
        # else:
        #     model_final = load('model_int.pt')
        # print('Best epoch : ',e)


def main_rt(args):
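    """Entry point for retention-time training: optionally set up wandb, load the
    data, build the model selected by args.model, and run the training loop.
    """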
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="RT prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())
    if args.dataset_train == args.dataset_test:
        data_train, data_val, data_test = load_data(batch_size=args.batch_size, n_train=args.n_train, n_test=args.n_test,
                                                    data_sources=[args.dataset_train, args.dataset_train, args.dataset_train])
    else:
        data_train, data_val, data_test = load_data(batch_size=args.batch_size, n_train=args.n_train, n_test=args.n_test,
                                                    data_sources=[args.dataset_train, args.dataset_train, args.dataset_test])
    print('\nData loaded')
    # if args.model == 'RT_self_att' :
    #     model = RT_pred_model_self_attention()
    if args.model == 'RT_multi':
        model = RT_pred_model_self_attention_multi(
            recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1], args.layers_sizes[2]),
            regressor_layer_size=args.layers_sizes[3])
    elif args.model in ('RT_self_att', 'RT_multi_sum'):
        model = RT_pred_model_self_attention_multi_sum(
            n_head=args.n_head,
            recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1]),
            regressor_layer_size=args.layers_sizes[2])
    elif args.model == 'RT_transformer':
        model = RT_pred_model_transformer(regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')
    run_rt(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_val, optimizer=optimizer,
           criterion=torch.nn.MSELoss(), metric=distance, wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()


def main_pretext(args):
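    """Entry point for RT training with the sequence-reconstruction pretext task."""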
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="RT prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())
    if args.dataset_train == args.dataset_test:
        data_train, data_val, data_test = load_data(args.batch_size, args.n_train, args.n_test,
                                                    data_source=args.dataset_train)
    else:
        data_train, _, _ = load_data(args.batch_size, args.n_train, args.n_test,
                                     data_source=args.dataset_train)
        _, data_val, data_test = load_data(args.batch_size, args.n_train, args.n_test,
                                           data_source=args.dataset_test)
    print('\nData loaded')
    model = RT_pred_model_self_attention_pretext(
        recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1]),
        regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')
    run_pretext(args.epochs, args.eval_inter, model, data_train, data_val, data_test, optimizer=optimizer,
                criterion=torch.nn.MSELoss(), task=torch.nn.CrossEntropyLoss(), metric=distance, coef=args.coef_pretext,
                wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()


def main_int(args):
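    """Entry point for fragment-intensity training from the .npy files under data/intensity/."""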
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'

        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Intensity prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print(torch.cuda.is_available())

    sources_train = ('data/intensity/sequence_train.npy',
                     'data/intensity/intensity_train.npy',
                     'data/intensity/collision_energy_train.npy',
                     'data/intensity/precursor_charge_train.npy')

    sources_test = ('data/intensity/sequence_test.npy',
                    'data/intensity/intensity_test.npy',
                    'data/intensity/collision_energy_test.npy',
                    'data/intensity/precursor_charge_test.npy')

    data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3],
                                           args.batch_size)
    data_val = load_intensity_from_files(sources_test[0], sources_test[1], sources_test[2], sources_test[3],
                                         args.batch_size)

    print('\nData loaded')
    model = Intensity_pred_model_multi_head(
        recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1]),
        regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # note: fixed learning rate, args.lr is not used here
    print('\nModel initialised')
    run_int(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_val, optimizer=optimizer,
            criterion=masked_cos_sim, metric=masked_spectral_angle, wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()


if __name__ == "__main__":
    args = load_args()

    if args.model in ('RT_self_att', 'RT_multi', 'RT_multi_sum', 'RT_transformer'):
        main_rt(args)
    elif args.model == 'Intensity_multi_head':
        main_int(args)
    elif args.model == 'RT_pretext':
        main_pretext(args)