import os
import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb

import common_dataset
import dataloader
from config_common import load_args
from common_dataset import load_data
from dataloader import load_data
from loss import masked_cos_sim, distance, masked_spectral_angle
from model_custom import Model_Common_Transformer
from model import RT_pred_model_self_attention_multi


def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
          wandb=None):
    """Run one training epoch and report batch-averaged losses and metrics.

    Args:
        model: network exposing ``forward(seq, charge)`` (``'both'``), ``forward_rt(seq)`` (``'rt'``)
            or ``forward_int(seq, charge)`` (``'int'``), depending on ``forward``.
        data_train: iterable of batches supporting ``len()`` (used as the number of batches).
            Batch layout per mode: ``(seq, charge, rt, intensity)`` for ``'both'``, ``(seq, rt)`` for
            ``'rt'`` and ``(seq, charge, intensity)`` for ``'int'``.
        epoch: epoch index, used only in the logged output.
        optimizer: optimizer stepped once per batch.
        criterion_rt, criterion_intensity: loss callables called as ``criterion(target, prediction)``.
        metric_rt, metric_intensity: metric callables called as ``metric(target, prediction)``;
            they are only logged, never optimised.
        forward: ``'both'``, ``'rt'`` or ``'int'``.
        wandb: when not ``None``, the averages are also sent to ``wdb.log``.

    Raises:
        ValueError: if ``forward`` is not one of the supported modes.
    """
    if forward not in ('both', 'rt', 'int'):
        raise ValueError(f"Unsupported forward mode for training: {forward!r}")

    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.train()
    # eval/save_pred freeze the parameters; re-enable gradients before optimising.
    for param in model.parameters():
        param.requires_grad = True

    if forward == 'both':
        for seq, charge, rt, intensity in data_train:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_rt + loss_int
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()
            losses_rt += loss_rt.item()
            # Log the unscaled term: it is exactly what is back-propagated above and what
            # eval reports as "val int loss", so train/val curves stay comparable.
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train), "train int loss": losses_int / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), 'train int loss',
              losses_int / len(data_train), "train rt mean metric : ", dist_rt_acc / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))

    elif forward == 'rt':
        for seq, rt in data_train:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()
            losses_rt += loss_rt.item()
            optimizer.zero_grad()
            loss_rt.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), "train rt mean metric : ",
              dist_rt_acc / len(data_train))

    else:  # forward == 'int'
        # NOTE(review): eval and save_pred unpack 4-tuples ``(seq, charge, _, intensity)`` in 'int'
        # mode; confirm which layout the intensity-only training loader actually yields.
        for seq, charge, intensity in data_train:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss_int.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train int loss": losses_int / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train int loss',
              losses_int / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))


def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
    """Evaluate ``model`` on ``data_val`` and report batch-averaged losses and metrics.

    This module-level function shadows the ``eval`` builtin within this file.

    Args:
        model: network exposing ``forward(seq, charge)`` for ``'both'``, ``forward_rt(seq)`` for
            ``'rt'`` or ``forward_int(seq, charge)`` for ``'int'``.
        data_val: iterable of batches supporting ``len()``. Layout per mode: ``(seq, charge, rt,
            intensity)`` for ``'both'``, ``(seq, rt)`` for ``'rt'`` and ``(seq, charge, _, intensity)``
            for ``'int'``.
        epoch: epoch index, used only in the logged output.
        criterion_rt, criterion_intensity: loss callables called as ``criterion(target, prediction)``.
        metric_rt, metric_intensity: metric callables called as ``metric(target, prediction)``.
        forward: ``'both'``, ``'rt'`` or ``'int'``; any other value only freezes the model.
        wandb: when not ``None``, the averages are also sent to ``wdb.log``.

    Side effects:
        Switches ``model`` to eval mode and sets ``requires_grad = False`` on all of its
        parameters (``train`` switches them back on).
    """
    rt_loss_total = int_loss_total = 0.
    rt_metric_total = int_metric_total = 0.

    model.eval()
    # With every parameter frozen, autograd records no graph during evaluation.
    for weight in model.parameters():
        weight.requires_grad = False

    if forward == 'both':
        for seq, charge, rt, intensity in data_val:
            rt = rt.float()
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge = seq.cuda(), charge.cuda()
                rt, intensity = rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            rt_loss_total += criterion_rt(rt, pred_rt).item()
            int_loss_total += criterion_intensity(intensity, pred_int).item()
            rt_metric_total += metric_rt(rt, pred_rt).item()
            int_metric_total += metric_intensity(intensity, pred_int).item()

        n_batches = len(data_val)
        mean_rt_loss = rt_loss_total / n_batches
        mean_int_loss = int_loss_total / n_batches
        mean_rt_metric = rt_metric_total / n_batches
        mean_int_metric = int_metric_total / n_batches
        if wandb is not None:
            wdb.log({"val rt loss": mean_rt_loss,
                     "val int loss": mean_int_loss,
                     "val rt mean metric": mean_rt_metric,
                     "val int mean metric": mean_int_metric,
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val rt loss', mean_rt_loss, 'val int loss', mean_int_loss,
              "val rt mean metric : ", mean_rt_metric, "val int mean metric", mean_int_metric)

    elif forward == 'rt':  # adapted to prosit dataset format
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq = seq.cuda()
                rt = rt.cuda()
            pred_rt = model.forward_rt(seq)
            rt_loss_total += criterion_rt(rt, pred_rt).item()
            rt_metric_total += metric_rt(rt, pred_rt).item()

        n_batches = len(data_val)
        mean_rt_loss = rt_loss_total / n_batches
        mean_rt_metric = rt_metric_total / n_batches
        if wandb is not None:
            wdb.log({"val rt loss": mean_rt_loss,
                     "val rt mean metric": mean_rt_metric,
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val rt loss', mean_rt_loss, "val rt mean metric : ", mean_rt_metric)

    elif forward == 'int':  # adapted to prosit dataset format
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge = seq.cuda(), charge.cuda()
                intensity = intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            int_loss_total += criterion_intensity(intensity, pred_int).item()
            int_metric_total += metric_intensity(intensity, pred_int).item()

        n_batches = len(data_val)
        mean_int_loss = int_loss_total / n_batches
        mean_int_metric = int_metric_total / n_batches
        if wandb is not None:
            wdb.log({"val int loss": mean_int_loss,
                     "val int mean metric": mean_int_metric,
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val int loss', mean_int_loss, "val int mean metric", mean_int_metric)


def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv', file=False):
    """Train for ``epochs`` epochs, periodically evaluate/checkpoint, then dump predictions.

    ``forward`` selects how the train and validation stages are wired:
        ``'transfer'``: train in ``'rt'`` mode, evaluate and predict in ``'both'`` mode, checkpoint.
        ``'reverse'``: train in ``'both'`` mode, evaluate and predict in ``'rt'`` mode, checkpoint.
        anything else: use ``forward`` for every stage, no checkpoints, and pass ``file`` on to
        ``save_pred``.

    Evaluation runs every ``eval_inter`` epochs; checkpoints (``model_common_<epoch>.pt``) every
    ``save_inter`` epochs in checkpointing modes. ``data_test`` is accepted but not used.
    """
    if forward == 'transfer':
        train_mode, eval_mode, checkpointing, pred_file = 'rt', 'both', True, False
    elif forward == 'reverse':
        train_mode, eval_mode, checkpointing, pred_file = 'both', 'rt', True, False
    else:
        train_mode = eval_mode = forward
        checkpointing, pred_file = False, file

    for epoch in range(1, epochs + 1):
        train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity,
              train_mode, wandb=wandb)
        if epoch % eval_inter == 0:
            eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, metric_intensity, eval_mode,
                 wandb=wandb)
        # Short-circuit keeps save_inter unused when checkpointing is off.
        if checkpointing and epoch % save_inter == 0:
            save(model, 'model_common_' + str(epoch) + '.pt')

    # Predictions are always dumped in the same mode used for evaluation.
    save_pred(model, data_val, eval_mode, output, file=pred_file)


def main(args):
    """Load data, build ``Model_Common_Transformer`` and train it as configured by ``args``.

    ``args.forward`` selects the data pipeline:
        ``'both'``: train/val batches from ``common_dataset.load_data``;
        ``'rt'``: train/val batches from ``dataloader.load_data``;
        ``'transfer'``: training batches from ``dataloader``, validation batches from ``common_dataset``;
        ``'reverse'``: training batches from ``common_dataset``, validation batches from ``dataloader``.

    Raises:
        ValueError: if ``args.forward`` is not one of the modes above, since no data would be loaded.
    """
    supported_modes = ('both', 'rt', 'transfer', 'reverse')
    # Validate before starting a wandb run so a bad config fails fast and leaves no stray run behind.
    if args.forward not in supported_modes:
        raise ValueError(f"Unsupported forward mode {args.forward!r}; expected one of {supported_modes}")

    if args.wandb is not None:
        # Never hard-code credentials in source. Offline runs need no API key; to sync them later,
        # authenticate with `wandb login` or export WANDB_API_KEY in the environment.
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())

    if args.forward == 'both':
        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                                        path_val=args.dataset_val,
                                                        path_test=args.dataset_test,
                                                        batch_size=args.batch_size, length=args.seq_length,
                                                        pad=False, convert=False, vocab='unmod')
    elif args.forward == 'rt':
        data_train, data_val = dataloader.load_data(data_sources=[args.dataset_train, args.dataset_val,
                                                                  args.dataset_test],
                                                    batch_size=args.batch_size, length=args.seq_length)

    elif args.forward == 'transfer':
        data_train, _ = dataloader.load_data(data_sources=[args.dataset_train, 'database/data_holdout.csv',
                                                           'database/data_holdout.csv'],
                                             batch_size=args.batch_size, length=args.seq_length)

        _, data_val = common_dataset.load_data(path_train=args.dataset_val,
                                               path_val=args.dataset_val,
                                               path_test=args.dataset_test,
                                               batch_size=args.batch_size, length=args.seq_length,
                                               pad=True, convert=True, vocab='unmod')

    else:  # args.forward == 'reverse'
        _, data_val = dataloader.load_data(data_sources=['database/data_train.csv', args.dataset_val,
                                                         args.dataset_test],
                                           batch_size=args.batch_size, length=args.seq_length)

        data_train, _ = common_dataset.load_data(path_train=args.dataset_train,
                                                 path_val=args.dataset_train,
                                                 path_test=args.dataset_train,
                                                 batch_size=args.batch_size, length=args.seq_length,
                                                 pad=True, convert=True, vocab='unmod')

    print('\nData loaded')

    model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                                     decoder_int_ff=args.decoder_int_ff,
                                     n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                     decoder_int_num_layer=args.decoder_int_num_layer,
                                     decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                                     embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
                                     seq_length=args.seq_length)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
        data_val=data_val, data_test=data_val, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
        criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
        wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)

    if args.wandb is not None:
        wdb.finish()


def save(model, checkpoint_name):
    """Serialise the whole ``model`` object to ``checkpoints/<checkpoint_name>``.

    The ``checkpoints`` directory is created if it does not exist yet. ``torch.save`` pickles the
    full object, so loading it back requires the model's class to be importable.
    """
    print('\nModel Saving...')
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    target_path = os.path.join(checkpoint_dir, checkpoint_name)
    torch.save(model, target_path)


def load(path):
    """Return the object stored at ``checkpoints/<path>`` (as written by ``save``).

    NOTE(review): ``torch.load`` unpickles the file, so only load trusted checkpoints. Recent torch
    releases default to ``weights_only=True``, which may reject full-module pickles; confirm against
    the installed torch version.
    """
    checkpoint_path = os.path.join('checkpoints', path)
    return torch.load(checkpoint_path)


def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Every tensor yielded by ``model.named_parameters()`` is counted, whether or not it currently
    requires gradients. Shared parameters are counted once, since ``named_parameters`` removes
    duplicates by default.
    """
    # Tensor.numel() is the product of the tensor's dimensions (1 for a 0-d scalar parameter).
    return sum(param.numel() for _, param in model.named_parameters())

def save_pred(model, data_val, forward, output_path, file = False):
    """Run ``model`` over ``data_val`` and write predictions alongside targets to a CSV file.

    Args:
        model: network exposing ``forward(seq, charge)`` (``'both'``), ``forward_rt(seq)`` (``'rt'``)
            or ``forward_int(seq, charge)`` (``'int'``).
        data_val: iterable of batches. Layout per mode: ``(seq, charge, rt, intensity)`` for
            ``'both'`` (plus a trailing ``files`` tensor when ``file`` is true), ``(seq, rt)`` for
            ``'rt'`` and ``(seq, charge, _, intensity)`` for ``'int'``.
        forward: ``'both'``, ``'rt'`` or ``'int'``; any other value writes an empty DataFrame.
        output_path: destination passed to ``DataFrame.to_csv``.
        file: only used in ``'both'`` mode. It switches ``data_val.dataset`` into file mode so each
            batch also yields file identifiers, which are stored in a ``file`` column. File mode is
            switched back off afterwards, even if prediction raises.

    Side effects:
        Switches ``model`` to eval mode and sets ``requires_grad = False`` on its parameters.
    """
    data_frame = pd.DataFrame()
    model.eval()
    # Freeze parameters so inference builds no autograd graph.
    for param in model.parameters():
        param.requires_grad = False

    file_mode_enabled = False
    try:
        if forward == 'both':
            pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
            if file:
                data_val.dataset.set_file_mode(True)
                file_mode_enabled = True
            for batch in data_val:
                # In file mode each batch carries an extra trailing tensor of file identifiers.
                if file:
                    seq, charge, rt, intensity, files = batch
                    file_list.extend(files.detach().cpu().tolist())
                else:
                    seq, charge, rt, intensity = batch
                rt, intensity = rt.float(), intensity.float()
                if torch.cuda.is_available():
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pr_rt, pr_intensity = model.forward(seq, charge)
                pred_rt.extend(pr_rt.detach().cpu().tolist())
                pred_int.extend(pr_intensity.detach().cpu().tolist())
                seqs.extend(seq.detach().cpu().tolist())
                charges.extend(charge.detach().cpu().tolist())
                true_rt.extend(rt.detach().cpu().tolist())
                true_int.extend(intensity.detach().cpu().tolist())

            data_frame['rt pred'] = pred_rt
            data_frame['seq'] = seqs
            data_frame['pred int'] = pred_int
            data_frame['true rt'] = true_rt
            data_frame['true int'] = true_int
            data_frame['charge'] = charges
            if file:
                data_frame['file'] = file_list

        elif forward == 'rt':  # adapted to prosit dataset format
            pred_rt, seqs, true_rt = [], [], []
            for seq, rt in data_val:
                rt = rt.float()
                if torch.cuda.is_available():
                    seq, rt = seq.cuda(), rt.cuda()
                pr_rt = model.forward_rt(seq)
                pred_rt.extend(pr_rt.detach().cpu().tolist())
                seqs.extend(seq.detach().cpu().tolist())
                true_rt.extend(rt.detach().cpu().tolist())
            data_frame['rt pred'] = pred_rt
            data_frame['seq'] = seqs
            data_frame['true rt'] = true_rt

        elif forward == 'int':  # adapted to prosit dataset format
            pred_int, seqs, charges, true_int = [], [], [], []
            for seq, charge, _, intensity in data_val:
                intensity = intensity.float()
                if torch.cuda.is_available():
                    seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
                pr_intensity = model.forward_int(seq, charge)
                # Accumulate every batch so predictions align row-for-row with seqs/charges/targets.
                pred_int.extend(pr_intensity.detach().cpu().tolist())
                seqs.extend(seq.detach().cpu().tolist())
                charges.extend(charge.detach().cpu().tolist())
                true_int.extend(intensity.detach().cpu().tolist())
            data_frame['seq'] = seqs
            data_frame['pred int'] = pred_int
            data_frame['true int'] = true_int
            data_frame['charge'] = charges
    finally:
        # Only undo what this call changed; non-'both' datasets may not support file mode at all.
        if file_mode_enabled:
            data_val.dataset.set_file_mode(False)

    data_frame.to_csv(output_path)


if __name__ == "__main__":
    # Script entry point: parse the run configuration and hand it straight to main().
    main(load_args())