import os

import numpy as np
import tensorflow
from dlomix.data import RetentionTimeDataset
from dlomix.eval import TimeDeltaMetric
from dlomix.models import PrositRetentionTimePredictor
from matplotlib import pyplot as plt
from sklearn.metrics import r2_score
from tensorflow.keras import backend as K


def save_reg(pred, true, name):
    """Scatter plot of predicted vs. true retention times, annotated with R²."""
    r2 = round(r2_score(true, pred), 4)
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(true, pred, 'y,')
    ax.text(120, 20, f'R² = {r2}', fontsize=12)
    ax.set_xlabel("True")
    ax.set_ylabel("Pred")
    ax.set_xlim([-50, 200])
    ax.set_ylim([-50, 200])
    fig.savefig(name)
    plt.close(fig)  # close the figure so plots do not accumulate across epochs


def save_evol(pred_prev, pred, true, name):
    """Scatter plot of the current epoch's prediction error against the previous epoch's error."""
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(pred - true, pred_prev - true, 'y,')
    ax.set_xlabel("Current error")
    ax.set_ylabel("Previous error")
    ax.set_xlim([-50, 50])
    ax.set_ylim([-50, 50])
    fig.savefig(name)
    plt.close(fig)


class GradientCallback(tensorflow.keras.callbacks.Callback):
    """Logs gradient histograms to TensorBoard after each epoch and saves
    prediction scatter plots on the holdout set.

    Relies on the module-level globals `rtdata`, `test_rtdata` and
    `test_targets` defined in the __main__ block.
    """
    console = True

    def on_epoch_end(self, epoch, logs=None, evol=True, reg=True):
        # Grab a single training batch to compute gradients on.
        for features, y_true in rtdata.train_data:
            break
        with tensorflow.GradientTape() as tape:
            y_pred = self.model(features)  # forward pass
            loss = self.model.compiled_loss(y_true=y_true, y_pred=y_pred)  # compute loss
        # Compute gradients outside the tape context.
        gradients = tape.gradient(loss, self.model.trainable_weights)
        for weights, grads in zip(self.model.trainable_weights, gradients):
            tensorflow.summary.histogram(
                weights.name.replace(':', '_') + '_grads', data=grads, step=epoch, buckets=100)
        preds = self.model.predict(test_rtdata.test_data)
        if reg:
            save_reg(preds.flatten(), test_targets, 'fig/unstability/reg_epoch_' + str(epoch))
        if evol and epoch > 0:
            pred_prev = np.load('temp/mem_pred.npy')
            save_evol(pred_prev, preds.flatten(), test_targets, 'fig/evol/reg_epoch_' + str(epoch))
        np.save('temp/mem_pred.npy', preds.flatten())


def lr_warmup_cosine_decay(global_step,
                           warmup_steps,
                           hold=0,
                           total_steps=0,
                           start_lr=0.0,
                           target_lr=1e-3):
    """Learning-rate schedule: linear warmup, optional hold at target_lr, then cosine decay."""
    # Cosine decay from target_lr towards 0 over the steps after warmup (and hold).
    learning_rate = 0.5 * target_lr * (
            1 + np.cos(np.pi * (global_step - warmup_steps - hold) / float(total_steps - warmup_steps - hold)))

    # Linear warmup from start_lr to target_lr (reaches target_lr at the final warmup step).
    warmup_lr = start_lr + (target_lr - start_lr) * (global_step / warmup_steps)

    # Hold the target LR for `hold` steps after warmup, then switch to the cosine-decayed rate.
    if hold > 0:
        learning_rate = np.where(global_step > warmup_steps + hold,
                                 learning_rate, target_lr)

    # Use the warmup rate while global_step < warmup_steps, otherwise the held/decayed rate.
    learning_rate = np.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate
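

# A minimal sketch (not part of the training pipeline) that visualises the schedule
# above: linear warmup to target_lr, an optional hold, then cosine decay towards zero.
# The default step counts and the output path 'fig/lr_schedule.png' are illustrative
# assumptions, not values from the original experiments.
def plot_lr_schedule(total_steps=1000, warmup_steps=100, hold=50, target_lr=1e-3,
                     out_path='fig/lr_schedule.png'):
    steps = np.arange(total_steps)
    lrs = [lr_warmup_cosine_decay(step, warmup_steps, hold=hold,
                                  total_steps=total_steps, target_lr=target_lr)
           for step in steps]
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(steps, lrs)
    ax.set_xlabel("Step")
    ax.set_ylabel("Learning rate")
    fig.savefig(out_path)
    plt.close(fig)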


class WarmupCosineDecay(tensorflow.keras.callbacks.Callback):
    """Sets the optimizer learning rate before every batch using lr_warmup_cosine_decay.

    total_steps, warmup_steps and hold are counted in batches, not epochs.
    """

    def __init__(self, total_steps=0, warmup_steps=0, start_lr=0.0, target_lr=1e-3, hold=0):
        super(WarmupCosineDecay, self).__init__()
        self.start_lr = start_lr
        self.hold = hold
        self.total_steps = total_steps
        self.global_step = 0
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.lrs = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = self.model.optimizer.lr.numpy()  # read from the callback's model, not the global one
        self.lrs.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = lr_warmup_cosine_decay(global_step=self.global_step,
                                    total_steps=self.total_steps,
                                    warmup_steps=self.warmup_steps,
                                    start_lr=self.start_lr,
                                    target_lr=self.target_lr,
                                    hold=self.hold)
        K.set_value(self.model.optimizer.lr, lr)

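
# Usage sketch for WarmupCosineDecay (an assumption, not from the original run):
# the callback counts steps per *batch*, so total_steps should be roughly
# epochs * batches_per_epoch rather than the number of epochs, e.g.
#
#     batches_per_epoch = ...  # number of training batches per epoch
#     lr_callback = WarmupCosineDecay(total_steps=100 * batches_per_epoch,
#                                     warmup_steps=5 * batches_per_epoch,
#                                     start_lr=0.0, target_lr=1e-3, hold=0)
#     model.fit(rtdata.train_data, epochs=100, callbacks=[lr_callback])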

if __name__ == '__main__':
    # Create the output directories used below (TensorBoard summaries, figures,
    # cached predictions and saved targets).
    for directory in ("./metrics_lr", "./logs_lr", "./metrics_prosit",
                      "fig/unstability", "fig/evol", "temp", "results/pred_prosit_ori"):
        os.makedirs(directory, exist_ok=True)

    BATCH_SIZE = 1024
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv', sequence_col='sequence',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=BATCH_SIZE, test=True)
    test_targets = test_rtdata.get_split_targets(split="test")
    np.save('results/pred_prosit_ori/target.npy', test_targets)
    model = PrositRetentionTimePredictor(seq_length=30)
    model.compile(optimizer='adam',
                  loss='mse',
                  metrics=['mean_absolute_error', TimeDeltaMetric()])
    file_writer = tensorflow.summary.create_file_writer("./metrics_prosit")
    file_writer.set_as_default()

    # The TensorBoard callback's write_grads option has been removed, so gradient
    # histograms are logged manually via GradientCallback.
    gradient_cb = GradientCallback()
    # tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir="./logs_prosit")
    # lr_callback = WarmupCosineDecay(total_steps=100, warmup_steps=10, start_lr=0.0, target_lr=1e-3, hold=5)
    # rtdata.train_data is already batched by the dataset, so batch_size is not passed to fit().
    model.fit(rtdata.train_data, epochs=100, validation_data=rtdata.val_data,
              callbacks=[gradient_cb])