import os
import datetime

from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from dlomix.data import RetentionTimeDataset
from dlomix.eval import TimeDeltaMetric
from dlomix.models import PrositRetentionTimePredictor, RetentionTimePredictor
from dlomix.reports import RetentionTimeReport
import tensorflow


def save_reg(pred, true, name):
    """Save a scatter plot of predicted vs. true values with a fitted line.

    A first-degree polynomial is fitted to (pred, true), drawn as a dashed
    line over the scatter, and the R² score is annotated on the figure.
    The current matplotlib figure is cleared afterwards.

    Args:
        pred: 1-D array-like of predicted values (x-axis).
        true: 1-D array-like of ground-truth values (y-axis).
        name: output path passed directly to ``plt.savefig``.
    """
    coef = np.polyfit(pred, true, 1)
    poly1d_fn = np.poly1d(coef)
    # BUG FIX: sklearn's r2_score signature is (y_true, y_pred); the original
    # call passed (pred, true), which is not equivalent because R² is not
    # symmetric in its arguments.
    r2 = round(r2_score(true, pred), 4)
    plt.plot(pred, true, 'y,', pred, poly1d_fn(pred), '--k')
    # NOTE(review): the annotation position (120, 20) assumes retention times
    # fall in roughly this range — confirm against the data, or compute the
    # position from the axis limits instead.
    plt.text(120, 20, 'R² = ' + str(r2), fontsize=12)
    plt.savefig(name)
    plt.clf()


def track_train(model, epoch, test_rtdata, rtdata):
    """Custom gradient-tape training loop with per-step wandb logging.

    Trains ``model`` with Adam on MSE for ``epoch`` epochs. After each epoch
    the model is evaluated on the holdout split and a regression plot is
    saved via ``save_reg``. Logging runs in wandb offline mode.

    Args:
        model: Keras model to train.
        epoch: number of epochs to run.
        test_rtdata: RetentionTimeDataset providing ``test_data`` and the
            "test" split targets.
        rtdata: RetentionTimeDataset providing the batched ``train_data``.
    """
    test_targets = test_rtdata.get_split_targets(split="test")
    loss = tensorflow.keras.losses.MeanSquaredError()
    optimizer = tensorflow.keras.optimizers.Adam()

    # SECURITY(review): hard-coded API key committed to source — rotate this
    # key and read it from the environment / a secrets store instead.
    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

    wandb.init(project="Prosit ori full dataset", dir='./wandb_run', name='prosit ori')

    for e in range(epoch):
        for step, (X_batch, y_batch) in enumerate(rtdata.train_data):
            with tensorflow.GradientTape() as tape:
                predictions = model(X_batch, training=True)
                # Keras losses take (y_true, y_pred); MSE happens to be
                # symmetric, but keep the conventional order so swapping in
                # an asymmetric loss later does not silently break.
                l = loss(y_batch, predictions)
            grads = tape.gradient(l, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            # BUG FIX: the original logged the raw Python list of gradient
            # tensors, which wandb cannot chart; log scalar summaries instead.
            wandb.log({'loss': float(l),
                       'grad_global_norm': float(tensorflow.linalg.global_norm(grads))})
        predictions = model.predict(test_rtdata.test_data)
        save_reg(predictions.flatten(), test_targets, 'fig/unstability/reg_epoch_' + str(e))

    wandb.finish()


def train_step(model, optimizer, x_train, y_train, step):
    """Run one optimization step and write per-weight gradient histograms.

    Relies on the module-level ``loss_object``, ``train_loss`` and
    ``train_accuracy`` objects (currently commented out in ``__main__``);
    calling this as-is raises NameError until those are defined.

    Args:
        model: Keras model to update.
        optimizer: Keras optimizer applying the gradients.
        x_train: input batch.
        y_train: target batch.
        step: global step passed to ``tensorflow.summary.histogram``.
    """
    with tensorflow.GradientTape() as tape:
        # Trainable variables are watched automatically by GradientTape, so
        # the original (post-forward-pass) tape.watch(...) call was redundant
        # and has been removed.
        predictions = model(x_train, training=True)
        loss = loss_object(y_train, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # BUG FIX: the loop variable previously shadowed ``grads`` itself; use a
    # distinct name for the per-weight gradient tensor.
    for weights, grad in zip(model.trainable_weights, grads):
        tensorflow.summary.histogram(
            weights.name.replace(':', '_') + '_grads', data=grad, step=step)
    train_loss(loss)
    train_accuracy(y_train, predictions)


def test_step(model, x_test, y_test):
    """Evaluate the model on one test batch and update the test metrics.

    Depends on the module-level ``loss_object``, ``test_loss`` and
    ``test_accuracy`` objects being defined before this is called.
    """
    preds = model(x_test)
    batch_loss = loss_object(y_test, preds)
    test_loss(batch_loss)
    test_accuracy(y_test, preds)


def main():
    """Train a PrositRetentionTimePredictor via ``model.fit`` and report.

    Loads the train/val split and holdout set, trains for 100 epochs with
    MSE + time-delta metrics under offline wandb logging, then predicts on
    the holdout set and builds a RetentionTimeReport.
    """
    BATCH_SIZE = 256

    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    # Holdout dataset, built once; the original constructed the identical
    # dataset a second time after training and also computed an unused
    # ``test_targets`` before the fit.
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32, test=True)
    model = PrositRetentionTimePredictor(seq_length=30)

    model.compile(optimizer='adam',
                  loss='mse',
                  metrics=['mean_absolute_error', TimeDeltaMetric()])

    # SECURITY(review): hard-coded API key committed to source — rotate this
    # key and read it from the environment / a secrets store instead.
    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

    wandb.init(project="Prosit ori full dataset", dir='./wandb_run', name='prosit ori')

    history = model.fit(rtdata.train_data,
                        validation_data=rtdata.val_data,
                        epochs=100)

    wandb.finish()

    # NOTE(review): predictions/test_targets/report are computed but never
    # consumed here — presumably intended for plotting; confirm and either
    # use or drop them.
    predictions = model.predict(test_rtdata.test_data)
    test_targets = test_rtdata.get_split_targets(split="test")

    report = RetentionTimeReport(output_path="./output", history=history)


def main_track():
    """Entry point for the custom tracked training loop (``track_train``).

    Builds the train/val and holdout datasets, instantiates a
    RetentionTimePredictor and trains it for 100 epochs with per-step
    wandb gradient/loss logging.
    """
    BATCH_SIZE = 256

    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=BATCH_SIZE, test=True)
    # (Removed an unused ``test_targets`` local — track_train fetches the
    # split targets itself.)
    model = RetentionTimePredictor(seq_length=30)
    track_train(model, 100, test_rtdata, rtdata)


if __name__ == '__main__':
    # Debug/scratch entry point: the full gradient-tape pipeline (metrics,
    # summary writers, train_step/test_step wiring) below is commented out;
    # as written, only the first training batch of each epoch is printed.
    # Use main() or main_track() for actual training runs.
    # loss_object = tensorflow.keras.losses.MeanSquaredError()
    # optimizer = tensorflow.keras.optimizers.Adam()
    # train_loss = tensorflow.keras.metrics.Mean('train_loss', dtype=tensorflow.float32)
    # train_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('train_accuracy')
    # test_loss = tensorflow.keras.metrics.Mean('test_loss', dtype=tensorflow.float32)
    # test_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('test_accuracy')
    #
    # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    # test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
    # train_summary_writer = tensorflow.summary.create_file_writer(train_log_dir)
    # test_summary_writer = tensorflow.summary.create_file_writer(test_log_dir)

    BATCH_SIZE = 256
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32, test=True)
    # NOTE(review): test_targets is unused in the active (uncommented) code
    # path below — it only feeds the commented-out evaluation loop.
    test_targets = test_rtdata.get_split_targets(split="test")
    # model = RetentionTimePredictor(seq_length=30)
    #
    EPOCHS = 5

    for epoch in range(EPOCHS):
        # Print only the first batch per epoch, then stop — inspection of the
        # dataset pipeline output, not real training.
        for (x_train, y_train) in rtdata.train_data:
            print(x_train)
            break
        #     train_step(model, optimizer, x_train, y_train, epoch)
        # with train_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', train_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', train_accuracy.result(), step=epoch)
        #
        # for (x_test, y_test) in test_rtdata.test_data:
        #     test_step(model, x_test, y_test)
        # with test_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', test_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
        #
        # template = 'Epoch {}, Loss: {}, Absolute Error: {}, Test Loss: {}, Test Absolute Error: {}'
        # print(template.format(epoch + 1,
        #                       train_loss.result(),
        #                       train_accuracy.result(),
        #                       test_loss.result(),
        #                       test_accuracy.result()))
        #
        # # Reset metrics every epoch
        # train_loss.reset_states()
        # test_loss.reset_states()
        # train_accuracy.reset_states()
        # test_accuracy.reset_states()