from scipy.sparse import data
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse
import scipy.io
from sklearn.metrics import roc_auc_score , accuracy_score, f1_score, precision_score, recall_score
from datetime import datetime
import argparse
import time
from model import Dominant
from utils import load_anomaly_detection_dataset

class RUNNER:
    """Train and evaluate the project's ``Dominant`` model for node anomaly detection.

    The model reconstructs both the adjacency matrix and the node attributes.
    Each node's anomaly score is a weighted sum of its attribute and structure
    reconstruction errors (see ``loss_func``). Nodes are then flagged as
    outliers by ranking those scores.
    """

    def __init__(self, adj, features, labels, dim=20, lr=5e-3, alpha=0.5, dropout=0.1, device='cuda', epochs=50):
        """Load the dataset, build the model and its Adam optimizer.

        Args:
            adj, features, labels: raw inputs forwarded unchanged to
                ``load_anomaly_detection_dataset``.
            dim: hidden size passed to ``Dominant``.
            lr: Adam learning rate.
            alpha: weight of the attribute error in the per-node score; the
                structure error is weighted by ``1 - alpha``.
            dropout: dropout rate passed to ``Dominant``.
            device: anything accepted by ``torch.device`` (e.g. ``'cuda'``,
                ``'cuda:1'``, ``'cpu'``). The graph tensors and the model are
                moved there; ``self.label`` is left where the loader put it.
            epochs: number of full-batch training epochs run by ``fit``.
        """
        self.alpha = alpha
        self.epochs = epochs
        self.adj, self.attrs, self.label, self.adj_label = load_anomaly_detection_dataset(adj, features, labels)
        self.model = Dominant(feat_size=self.attrs.size(1), hidden_size=dim, dropout=dropout)
        # Normalise through torch.device so every device spec is honoured,
        # not only the literal string 'cuda'.
        device = torch.device(device)
        self.adj = self.adj.to(device)
        self.adj_label = self.adj_label.to(device)
        self.attrs = self.attrs.to(device)
        self.model = self.model.to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def loss_func(self, adj, A_hat, attrs, X_hat, alpha):
        """Compute per-node anomaly scores and the mean reconstruction costs.

        A node's attribute error is the L2 norm of its row of ``X_hat - attrs``.
        Its structure error is the L2 norm of its row of ``A_hat - adj``. Both
        norms are taken over dimension 1.

        Args:
            adj: target adjacency (this class passes ``self.adj_label``).
            A_hat: reconstructed adjacency, same shape as ``adj``.
            attrs: target node attributes.
            X_hat: reconstructed attributes, same shape as ``attrs``.
            alpha: weight of the attribute error; structure gets ``1 - alpha``.

        Returns:
            ``(cost, structure_cost, attribute_cost)``:
                - ``cost`` holds one score per row (node).
                - ``structure_cost`` is the mean of the per-node structure errors.
                - ``attribute_cost`` is the mean of the per-node attribute errors.
        """
        attribute_errors = torch.sqrt(torch.sum(torch.pow(X_hat - attrs, 2), 1))
        structure_errors = torch.sqrt(torch.sum(torch.pow(A_hat - adj, 2), 1))
        cost = alpha * attribute_errors + (1 - alpha) * structure_errors
        return cost, torch.mean(structure_errors), torch.mean(attribute_errors)

    @staticmethod
    def _predict_labels(score, outlier_percentage, invert=False):
        """Turn anomaly scores into 0/1 outlier predictions by rank.

        Let ``n = len(score)``.
            - Without ``invert``, nodes ranked at or above position
              ``int(n * (1 - outlier_percentage))`` (ascending order) are
              labelled 1, i.e. the highest-scoring ones.
            - With ``invert``, the ``int(n * outlier_percentage)``
              lowest-scoring nodes are labelled 1.

        Every other node is labelled 0. The returned array is aligned with
        ``score``, not sorted.
        """
        score = np.asarray(score)
        n = len(score)
        order = np.argsort(score)
        # Start from an all-zero array so no raw score can leak into the
        # predictions, whatever the value of outlier_percentage.
        pred = np.zeros(n, dtype=int)
        if invert:
            pred[order[:int(n * outlier_percentage)]] = 1
        else:
            pred[order[int(n * (1 - outlier_percentage)):]] = 1
        return pred

    def _evaluate(self, score, outlier_percentage, invert=False):
        """Return ``(acc, f1, rec, pre, auc)`` of ``score`` against ``self.label``."""
        score = np.asarray(score)
        label = np.asarray(self.label)
        pred = self._predict_labels(score, outlier_percentage, invert)
        acc = accuracy_score(label, pred)
        f1 = f1_score(label, pred)
        rec = recall_score(label, pred)
        pre = precision_score(label, pred)
        # ROC AUC needs the continuous ranking. On thresholded 0/1 predictions
        # it degenerates to balanced accuracy. With ``invert``, low scores mean
        # "outlier", so the ranking is negated.
        auc = roc_auc_score(label, -score if invert else score)
        return acc, f1, rec, pre, auc

    def fit(self, outlier_percentage, invert=False):
        """Train for ``self.epochs`` full-batch epochs, printing metrics each epoch.

        Args:
            outlier_percentage: fraction (0..1) of nodes flagged as outliers
                when thresholding the scores.
            invert: if True, the lowest-scoring nodes are flagged instead of
                the highest-scoring ones.

        Returns:
            ``(history, acc, f1, rec, pre, auc, elapsed)``:
                - ``history`` is a dict with three keys:
                    - ``'train'``: mean loss per epoch.
                    - ``'best'``: running minimum of that loss.
                    - ``'val'``: always empty.
                - The metrics come from the last epoch, or are ``None`` when
                  ``self.epochs`` is 0.
                - ``elapsed`` is the wall-clock time in seconds.
        """
        train_losses = []
        best_losses = []
        acc = f1 = rec = pre = auc = None
        start = time.time()
        for epoch in range(self.epochs):
            self.model.train()
            self.optimizer.zero_grad()
            A_hat, X_hat = self.model(self.attrs, self.adj)
            loss, struct_loss, feat_loss = self.loss_func(self.adj_label, A_hat, self.attrs, X_hat, self.alpha)
            mean_loss = torch.mean(loss)
            train_losses.append(mean_loss.item())
            best_losses.append(min(best_losses[-1], train_losses[-1]) if best_losses else train_losses[-1])

            mean_loss.backward()
            self.optimizer.step()
            print("Epoch:", f"{epoch:04d}",
                  "train_loss=", f"{mean_loss.item():.5f}",
                  "str_loss=", f"{struct_loss.item():.5f}",
                  "feat_loss=", f"{feat_loss.item():.5f}")
            # Metrics use this epoch's training-mode scores, which were
            # computed before the optimizer step.
            acc, f1, rec, pre, auc = self._evaluate(loss.detach().cpu().numpy(), outlier_percentage, invert)
            print('Acc', acc, 'f1', f1, 'Rec', rec, 'pre', pre, 'Auc', auc)

        return {'train': train_losses, 'best': best_losses, 'val': []}, acc, f1, rec, pre, auc, time.time() - start

    def predict(self, outlier_percentage, invert=False):
        """Score every node with the model in eval mode and compute metrics.

        Args:
            outlier_percentage: fraction (0..1) of nodes flagged as outliers.
            invert: if True, the lowest-scoring nodes are flagged instead.

        Returns:
            ``(mean_score, acc, f1, rec, pre, auc)``.
        """
        self.model.eval()
        # Inference only: no need to build the autograd graph.
        with torch.no_grad():
            A_hat, X_hat = self.model(self.attrs, self.adj)
            loss, _, _ = self.loss_func(self.adj_label, A_hat, self.attrs, X_hat, self.alpha)
        score = loss.detach().cpu().numpy()
        print('\n')
        print('\n')
        acc, f1, rec, pre, auc = self._evaluate(score, outlier_percentage, invert)
        return np.mean(score), acc, f1, rec, pre, auc