import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.optim as optim
import wandb as wdb
from sklearn.metrics import confusion_matrix

from config import load_args_contrastive
from dataset_ref import load_data_duo
from model import Classification_model_duo_contrastive


def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
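    """Train for one epoch; return the mean per-sample loss and contrastive accuracy.

    `wandb` carries the run name from the CLI (or None); actual logging goes
    through the module-level `wdb` alias.
    """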
    model.train()
    losses = 0.
    acc = 0.
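    # Re-enable gradients in case a previous validation pass froze them.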
    for param in model.parameters():
        param.requires_grad = True

    for imaer, imana, img_ref, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits, dim=1)
        acc += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        # CrossEntropyLoss averages over the batch, so weight by the batch
        # size to recover a correct per-sample mean over the dataset.
        losses += loss.item() * label.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses = losses / len(data_train.dataset)
    acc = acc / len(data_train.dataset)
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))

    if wandb is not None:
        wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})

    return losses, acc


def val_duo(model, data_test, loss_function, epoch, wandb):
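    """Evaluate the model; return (loss, acc, acc_contrastive).

    Each batch holds one query compared against every reference class, so the
    tensors are transposed to put the reference-class dimension first. `acc`
    measures whether the right reference class wins; `acc_contrastive`
    measures per-pair match/no-match decisions.
    """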
    model.eval()
    losses = 0.
    acc = 0.
    acc_contrastive = 0.
    n_pairs = 0  # total number of (query, reference) pairs seen
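    # Freeze the parameters: no gradients are needed during evaluation
    # (train_duo turns them back on at the start of the next epoch).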
    for param in model.parameters():
        param.requires_grad = False

    for imaer, imana, img_ref, label in data_test:
        # Swap the batch and pair dimensions so dim 0 enumerates the
        # reference classes compared against the query sample.
        imaer = imaer.transpose(0, 1)
        imana = imana.transpose(0, 1)
        img_ref = img_ref.transpose(0, 1)
        label = label.transpose(0, 1)
        label = label.squeeze().long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        # The matching reference pair is labelled 0, so argmin recovers the
        # true class; logit column 0 scores the "match" class, so the
        # predicted class is the pair with the highest match logit.
        label_class = torch.argmin(label).item()
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits[:, 0]).item()
        acc_contrastive += (torch.argmax(pred_logits, dim=1) == label).sum().item()
        acc += int(pred_class == label_class)
        loss = loss_function(pred_logits, label)
        # CrossEntropyLoss averages over the pairs in the batch, so weight by
        # the pair count to get a true per-pair mean over the whole set.
        losses += loss.item() * label.shape[0]
        n_pairs += label.shape[0]
    losses = losses / n_pairs
    acc = acc / len(data_test.dataset)
    acc_contrastive = acc_contrastive / n_pairs
    print('Test epoch  {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
                                                                                       acc_contrastive))

    if wandb is not None:
        wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
                 "validation contrastive accuracy": acc_contrastive})

    return losses, acc, acc_contrastive


def run_duo(args):
    # wandb init
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        # Offline mode: runs are stored locally and can be synced later with `wandb sync`.
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
        wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb)

    # load data
    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                                base_dir_val=args.dataset_val_dir,
                                                                base_dir_test=args.dataset_test_dir,
                                                                batch_size=args.batch_size,
                                                                ref_dir=args.dataset_ref_dir,
                                                                positive_prop=args.positive_prop, sampler=args.sampler)

    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
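    # Cast parameters to float64 to match the double-precision input tensors.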
    model.double()
    # load weight
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
        print('Model weights loaded')
    # move parameters to GPU
    if torch.cuda.is_available():
        print('Model loaded on GPU')
        model = model.cuda()

    # init accumulators
    best_loss = float('inf')
    train_acc = []
    train_loss = []
    val_acc = []
    val_cont_acc = []
    val_loss = []
    # init training
    loss_function = nn.CrossEntropyLoss()
    if args.opti == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # train model
    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb)
            val_loss.append(loss)
            val_acc.append(acc)
            val_cont_acc.append(acc_contrastive)
            if loss < best_loss:
                save_model(model, args.save_path)
                best_loss = loss
        if e % args.test_inter == 0 and args.dataset_test_dir is not None:
            # Periodic test-set check; kept out of the validation lists so the
            # training plot below shows validation curves only.
            val_duo(model, data_test_batch, loss_function, e, args.wandb)
    # plot and save training figs
    if args.wandb is None:
        plt.clf()
        plt.subplot(2, 1, 1)
        plt.plot(train_acc, label='train cont acc')
        plt.plot(val_cont_acc, label='val cont acc')
        plt.plot(val_acc, label='val classification acc')
        plt.title('Train and validation accuracy')
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.legend(loc="upper left")
        plt.ylim(0, 1.05)
        plt.tight_layout()

        plt.subplot(2, 1, 2)
        plt.plot(train_loss, label='train')
        plt.plot(val_loss, label='val')
        plt.title('Train and validation loss')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend(loc="upper left")
        plt.tight_layout()

        # Save before show(): some backends clear the figure on show(), which
        # would write an empty image.
        plt.savefig(args.base_out + '_training_plot.png')
        plt.show()

    # load and evaluate best model
    load_model(model, args.save_path)
    if args.dataset_test_dir is not None:
        make_prediction_duo(model, data_test_batch, args.base_out + '_confusion_matrix_test.png',
                            args.base_out + '_confidence_matrix_test.png')

    make_prediction_duo(model, data_val_batch, args.base_out + '_confusion_matrix_val.png',
                        args.base_out + '_confidence_matrix_val.png')

    if args.wandb is not None:
        wdb.finish()


def make_prediction_duo(model, data, f_name, f_name2):
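    """Run the model on `data` and save two heatmaps: a confusion matrix
    (`f_name`) and a matrix of mean match confidences per true class (`f_name2`).
    """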
    # Peek at one batch to get the number of reference classes.
    imaer, imana, img_ref, label = next(iter(data))
    n_class = label.shape[1]
    confidence_pred_list = [[] for _ in range(n_class)]
    y_pred = []
    y_true = []
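    # Softmax over the two logits turns column 0 into a per-pair "match" confidence.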
    soft_max = nn.Softmax(dim=1)
    # iterate over test data
    for imaer, imana, img_ref, label in data:
        # Same layout as in val_duo: dim 0 enumerates the reference classes.
        imaer = imaer.transpose(0, 1)
        imana = imana.transpose(0, 1)
        img_ref = img_ref.transpose(0, 1)
        label = label.transpose(0, 1)
        label = label.squeeze().long()
        # The true class is the reference pair labelled 0.
        specie = torch.argmin(label)

        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        output = model(imaer, imana, img_ref)
        confidence = soft_max(output)
        confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy())
        # Single-class prediction: keep only the most confident positive pair.
        output = torch.argmax(output[:, 0])
        label = torch.argmin(label)
        y_pred.append(output.item())
        y_true.append(label.item())  # save ground truth

    # Build confusion matrix
    classes = data.dataset.classes
    cf_matrix = confusion_matrix(y_true, y_pred)
    confidence_matrix = np.zeros((n_class, n_class))
    for i in range(n_class):
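        # Mean softmax "match" confidence over all samples whose true class is i.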
        confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0)

    # Row-normalised rates drive the colour map; raw counts go in the annotations.
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=classes, columns=classes)
    print('Saving Confusion Matrix')
    plt.clf()
    plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=cf_matrix)
    plt.savefig(f_name)

    df_cm = pd.DataFrame(confidence_matrix, index=classes, columns=classes)
    print('Saving Confidence Matrix')
    plt.clf()
    plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=confidence_matrix)
    plt.savefig(f_name2)


def save_model(model, path):
    print('Model saved')
    torch.save(model.state_dict(), path)


def load_model(model, path):
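    # weights_only=True restricts unpickling to tensor data (requires a recent PyTorch).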
    model.load_state_dict(torch.load(path, weights_only=True))


if __name__ == '__main__':
    args = load_args_contrastive()
    print(args)
    run_duo(args)