# coding = utf-8
import copy
import itertools
import random
from collections import Counter

try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3: the ABCs still lived directly in `collections`
    from collections import Iterable

import networkx as nx
import numpy as np
import pandas as pd
from cdlib import algorithms
from networkx.generators.degree_seq import _to_stublist

# Machine epsilon of a float64; added to distances before they are inverted so that
# coincident points do not cause a division by zero.
float_epsilon = np.finfo(np.float64).eps




def powerlaw(nb_nodes, nb_edges, exponent=2, tries=100, min_deg=1):
    """
    Return an integer degree sequence that follows a power law and sums to ``2 * nb_edges``.

    Power-law sequences of ``nb_nodes`` values are drawn (rounded, then filtered to keep only
    degrees ``>= min_deg``). The draw whose sum is closest to the number of stubs
    (``2 * nb_edges``) is kept. If no draw matches exactly, random entries are incremented or
    decremented by one until the sum matches. Decrements never push a degree below
    ``max(min_deg, 0)``, so the correction may stop short of the target when every entry is
    already at that floor.

    Parameters
    ----------
    nb_nodes : int
        number of values drawn per attempt; entries below ``min_deg`` are dropped, so the
        returned sequence may be shorter
    nb_edges : int
        target number of edges
    exponent : float
        power-law exponent given to ``nx.utils.powerlaw_sequence``
    tries : int
        number of additional draws attempted
    min_deg : int
        minimum degree kept in a draw

    Returns
    -------
    np.ndarray
        integer degree sequence
    """
    nb_stubs = nb_edges * 2

    def _draw():
        seq = np.round(nx.utils.powerlaw_sequence(nb_nodes, exponent=exponent))
        return seq[seq >= min_deg]

    degs = _draw()
    sum_deg = degs.sum()
    for _ in range(tries):
        if sum_deg == nb_stubs:
            break
        new_draw = _draw()
        new_sum_deg = new_draw.sum()
        # Keep the draw whose degree sum is closest to the number of stubs.
        if abs(nb_stubs - new_sum_deg) < abs(nb_stubs - sum_deg):
            degs, sum_deg = new_draw, new_sum_deg

    degs = degs.astype(int)
    diff = int(nb_stubs - degs.sum())
    if diff == 0 or len(degs) == 0:
        return degs
    if diff > 0:
        # np.add.at accumulates repeated indexes, like incrementing one index at a time.
        np.add.at(degs, np.random.choice(np.arange(len(degs)), diff), 1)
    else:
        floor = max(min_deg, 0)
        for _ in range(-diff):
            # Only degrees above the floor may be decremented (never below min_deg / 0).
            candidates = np.flatnonzero(degs > floor)
            if candidates.size == 0:
                break
            degs[np.random.choice(candidates)] -= 1
    return degs


def get_countries_coords():
    """
    Return the centroid coordinates of each country in the world.

    Countries are read from the ``naturalearth_lowres`` dataset bundled with geopandas.

    Returns
    -------
    np.ndarray
        one ``[x, y]`` row per country centroid (presumably longitude / latitude)

    Raises
    ------
    ImportError
        if geopandas is not installed
    """
    try:
        import geopandas as gpd
    except ImportError as err:
        raise ImportError("Geopandas is not installed !") from err
    # NOTE(review): `gpd.datasets` was deprecated and removed in geopandas 1.0, so this call
    # presumably needs an older geopandas release. Confirm the pinned version.
    gdf = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))

    return np.asarray(gdf.centroid.apply(lambda point: [point.x, point.y]).values.tolist())


def _conf_model(degree_seq, max_passes=100):
    """
    Build a simple graph from a degree sequence with a stub-matching configuration model.

    Node ``i`` receives ``degree_seq[i]`` stubs. The stubs are shuffled and paired two by two.
    A pair is accepted unless it is a self-loop or repeats an edge already created. Rejected
    stubs (and an odd leftover stub) are reshuffled and paired again, for at most
    ``max_passes`` passes. Stubs still unmatched afterwards are dropped. Realised degrees can
    therefore be lower than requested, and nodes with no matched stub are absent from the graph.

    Parameters
    ----------
    degree_seq : sequence of int
        requested degree of each node; node labels are the positions in the sequence
    max_passes : int
        maximum number of shuffle-and-pair passes

    Returns
    -------
    nx.Graph
        generated graph
    """
    # Same stub list as networkx's private `_to_stublist`, built here to avoid private API.
    stubs = [node for node, deg in enumerate(degree_seq) for _ in range(int(deg))]
    random.shuffle(stubs)
    seen = set()
    edges = []
    for _ in range(max_passes):
        if len(stubs) < 2:
            break
        leftover = []
        # Walk over every consecutive pair, including the last one.
        for i in range(0, len(stubs) - 1, 2):
            u, v = stubs[i], stubs[i + 1]
            key = frozenset((u, v))
            if u == v or key in seen:
                leftover.extend((u, v))
                continue
            seen.add(key)
            edges.append((u, v))
        if len(stubs) % 2:
            leftover.append(stubs[-1])
        stubs = leftover
        random.shuffle(stubs)
    return nx.from_edgelist(edges)


def powerlaw_graph(nb_nodes, nb_edges, exponent=2, tries=1000, min_deg=0):
    """
    Generate a graph with a given number of vertices and edges whose degree distribution
    follows a power law.

    A power-law degree sequence is drawn with `powerlaw` and turned into a graph with the
    stub-matching configuration model (`_conf_model`, Molloy-Reed style). The draw is repeated
    until every node of the sequence appears in the graph. The edge count is then adjusted
    to ``nb_edges``:
    - too few edges: unconnected pairs of distinct nodes are linked;
    - too many edges: random edges whose endpoints keep at least one other edge are removed.

    Parameters
    ----------
    nb_nodes : int
    nb_edges : int
    exponent : float
        power-law exponent
    tries : int
        maximum number of configuration-model attempts (also forwarded to `powerlaw`)
    min_deg : int
        minimum degree kept when drawing the degree sequence

    Returns
    -------
    nx.Graph
        generated graph

    Raises
    ------
    RuntimeError
        if no attempt produced a graph with ``nb_nodes`` nodes
    """
    def _draw_graph():
        return _conf_model(powerlaw(nb_nodes, nb_edges, exponent, tries, min_deg).astype(int))

    G = _draw_graph()
    attempt = 0
    while len(G) != nb_nodes and attempt < tries:
        G = _draw_graph()
        attempt += 1
    if len(G) != nb_nodes:
        raise RuntimeError(
            "Cant compute configuration model based on parameters "
            "(nb_nodes={0}, nb_edges={1}, exponent={2})".format(nb_nodes, nb_edges, exponent))

    if G.size() < nb_edges:
        # Too few edges: link pairs of distinct nodes that are not connected yet.
        nodes = list(G.nodes())
        for n in nodes:
            if G.size() >= nb_edges:
                break
            for n2 in nodes:
                if G.size() >= nb_edges:
                    break
                if n != n2 and not G.has_edge(n, n2):
                    G.add_edge(n, n2)
    elif G.size() > nb_edges:
        # Too many edges: drop random edges, never isolating an endpoint.
        edges_ = list(G.edges())
        random.shuffle(edges_)
        to_remove = G.size() - nb_edges
        for u, v in edges_:
            if to_remove == 0:
                break
            if G.degree(u) > 1 and G.degree(v) > 1:
                G.remove_edge(u, v)
                to_remove -= 1
    return G


def _spatial_resolve_coords(nb_nodes, coords):
    """Turn the ``coords`` option of `spatial_graph` into ``nb_nodes`` positions indexable by node."""
    # Check the container type first: truth-testing a numpy array raises a ValueError.
    if coords is not None and isinstance(coords, Iterable) and not isinstance(coords, str):
        if len(coords) != nb_nodes:
            raise ValueError("number of nodes must match the size of the coords dict")
        return coords
    if coords == "random":
        # First column in [-180, 180), second in [-90, 90) (presumably longitude / latitude).
        coords = np.random.random(nb_nodes * 2).reshape(nb_nodes, 2)
        coords[:, 0] = (coords[:, 0] * 360) - 180
        coords[:, 1] = (coords[:, 1] * 180) - 90
        return coords
    # Any other value falls back to country centroids.
    coords = get_countries_coords()
    if nb_nodes > len(coords):
        raise ValueError(
            "Too many nodes for coords = \"country\". Change nb_nodes value or change coords to 'random' or your own list of coords")
    # Pick distinct countries: duplicated positions would be at distance 0 and dominate the weights.
    coords_index = np.random.choice(np.arange(len(coords)), nb_nodes, replace=False)
    return coords[coords_index]


def _spatial_candidate_pairs(nb_nodes, coords, dist_func, self_link):
    """Return a (src, tar, weight) DataFrame of every unordered node pair, weights normalised to sum to 1."""
    rows = []
    for i in range(nb_nodes):
        # Each unordered pair appears once: i <= j with self-links, i < j otherwise.
        for j in range(i if self_link else i + 1, nb_nodes):
            rows.append([i, j, 1 / (float_epsilon + dist_func(coords[i], coords[j]))])
    df = pd.DataFrame(rows, columns=["src", "tar", "weight"]).astype({"src": int, "tar": int})
    df["weight"] = df.weight / df.weight.sum()
    return df


def _spatial_reconnect_missing_nodes(G, df, nb_nodes, weighted=False):
    """Give every node of ``range(nb_nodes)`` absent from ``G`` one edge, removing up to as many edges elsewhere."""
    missing_nodes = sorted(set(range(nb_nodes)) - set(G.nodes()))
    diff = len(missing_nodes)

    # Removal candidates: for the highest-degree nodes (degree > 2), every incident edge but
    # one. Nodes are added until they can spare more than `diff` edges or none is left.
    df_deg = pd.DataFrame(list(G.degree()), columns=["node", "degree"])
    df_deg = df_deg[df_deg["degree"] > 2].sort_values(by="degree", ascending=False)
    removable, seen, spare = [], set(), 0
    for node, degree in zip(df_deg["node"].tolist(), df_deg["degree"].tolist()):
        node_edges = list(G.edges(node))
        random.shuffle(node_edges)
        for u, v in node_edges[1:]:
            key = frozenset((u, v))
            if key not in seen:  # an edge can be listed by both of its endpoints
                seen.add(key)
                removable.append((u, v))
        spare += degree - 1
        if spare > diff:
            break
    if removable:
        picked = np.random.choice(len(removable), size=min(diff, len(removable)), replace=False)
        G.remove_edges_from([removable[k] for k in picked])

    for node in missing_nodes:
        candidates = df[(df["src"] == node) | (df["tar"] == node)]
        row = candidates.sample(n=1, weights="weight").iloc[0]
        u, v = int(row["src"]), int(row["tar"])
        if weighted:
            G.add_edge(u, v, weight=row["weight"])
        else:
            G.add_edge(u, v)


def spatial_graph(nb_nodes, nb_edges, coords="country", dist_func=lambda a, b: np.linalg.norm(a - b) ** 2,
                  self_link=False, weighted=False):
    """
    Generate a spatial graph with a specific number of vertices and edges.

    Every candidate node pair is weighted by ``1 / (eps + dist_func(pos_u, pos_v))``, normalised
    over all pairs, and ``nb_edges`` distinct pairs are sampled with these weights. Up to 50
    samples are drawn and the one covering the most nodes is kept. Nodes still missing then get
    one extra edge each, while edges are removed from high-degree nodes.

    Parameters
    ----------
    nb_nodes : int
    nb_edges : int
        must be at least ``nb_nodes``
    coords : array of shape (n,2), other sized iterable indexable by node, or str
        if str, possible choices are "random" or "country" (any other value uses "country")
    dist_func : callable
        distance between two positions
    self_link : bool
        whether a node may be linked to itself
    weighted : bool
        store the normalised sampling weight in the ``"weight"`` edge attribute

    Returns
    -------
    nx.Graph
        generated graph; each node stores its position in the ``"pos"`` attribute

    Raises
    ------
    ValueError
        if ``nb_nodes > nb_edges``, if ``coords`` does not match ``nb_nodes``, if ``nb_edges``
        exceeds the number of candidate pairs, or if "country" mode is asked for more nodes
        than there are countries
    """
    if nb_nodes > nb_edges:
        raise ValueError(
            "To obtain a specific nb of nodes, the number of edges must be equal or superior to the number of nodes !")
    coords = _spatial_resolve_coords(nb_nodes, coords)
    df = _spatial_candidate_pairs(nb_nodes, coords, dist_func, self_link)
    if nb_edges > len(df):
        raise ValueError(
            "nb_edges ({0}) exceeds the number of candidate node pairs ({1})".format(nb_edges, len(df)))

    edge_attr = "weight" if weighted else None
    best_G = None
    for _ in range(50):
        sample = df.sample(n=nb_edges, weights="weight")
        G = nx.from_pandas_edgelist(sample, source="src", target="tar", edge_attr=edge_attr)
        # Keep the sample whose node count is closest to nb_nodes (earliest wins ties).
        if best_G is None or abs(len(best_G) - nb_nodes) > abs(len(G) - nb_nodes):
            best_G = G
        if len(G) == nb_nodes and G.size() == nb_edges:
            break
    G = best_G

    if len(G) != nb_nodes:
        _spatial_reconnect_missing_nodes(G, df, nb_nodes, weighted)

    for n in list(G.nodes()):
        G.nodes[n]["pos"] = coords[n]
    return G


def ER_graph(nb_nodes, nb_edges):
    """
    Build an Erdos-Renyi G(n, m) random graph with networkx's dense G(n, m) generator.

    Parameters
    ----------
    nb_nodes : int
        number of vertices
    nb_edges : int
        number of edges

    Returns
    -------
    nx.Graph
        generated graph
    """
    graph = nx.dense_gnm_random_graph(nb_nodes, nb_edges)
    return graph


def stochastic_block_model_graph(nb_nodes, nb_edges, nb_com, percentage_edge_betw, verbose=False):
    """
    Generate a stochastic block model graph with a defined number of vertices and edges.

    Nodes are split into ``nb_com`` equally sized blocks, and ``nb_edges`` node pairs are
    sampled without replacement. Between-block pairs get weight ``p_out`` and within-block
    pairs get ``p_in``. The two weights are chosen so that between-block pairs carry a
    ``percentage_edge_betw`` share of the total sampling weight. Nodes left without any edge
    are reconnected with `equilibrate`.

    Parameters
    ----------
    nb_nodes : int
        number of vertices; must be a multiple of ``nb_com``
    nb_edges : int
    nb_com : int
        number of blocks
    percentage_edge_betw : float
        share (between 0 and 1) of the sampling weight given to between-block pairs
    verbose : bool
        print intermediate quantities

    Returns
    -------
    nx.Graph
        generated graph; each node stores its block index in the ``"block"`` attribute

    Raises
    ------
    ValueError
        if ``nb_nodes`` is not a multiple of ``nb_com``, if ``nb_edges`` is too large, or if
        no node pair receives a positive sampling weight
    """
    if nb_nodes % nb_com != 0:
        raise ValueError("Modulo between the number of nodes and community must be equal to 0")

    edge_max = (1 / nb_com) * ((nb_nodes * (nb_nodes - 1)) / 2)
    if nb_edges > edge_max:
        raise ValueError("nb_edges must be inferior to {0}".format(edge_max))

    def nb_of_pair(N):
        return (N * (N - 1)) / 2

    # The planted partition graph (p_in = p_out = 1) only provides the node -> block assignment.
    G = nx.planted_partition_graph(nb_com, nb_nodes // nb_com, 1, 1)
    block_assign = nx.get_node_attributes(G, "block")
    if verbose:
        print(G.size())

    u_in = sum(nb_of_pair(size) for size in Counter(block_assign.values()).values())
    u_out = nb_of_pair(len(G)) - u_in
    l_out = nb_edges * percentage_edge_betw
    l_in = nb_edges - l_out
    # A single block leaves no between-block pair; singleton blocks leave no within-block pair.
    p_out = l_out / u_out if u_out else 0.0
    p_in = l_in / u_in if u_in else 0.0
    if verbose:
        print("u_in", u_in)
        print("u_out", u_out)
        print("l_out", l_out)
        print("l_in", l_in)
        print("p_in", p_in)
        print("p_out", p_out)

    inter_list, intra_list = get_inter_intra_edges(G, G.is_directed())
    if verbose:
        print(len(inter_list), len(intra_list))

    # One array built from the joined lists keeps the node dtype even if one list is empty.
    all_edges = np.asarray(inter_list + intra_list)
    all_probs = np.concatenate((np.full(len(inter_list), p_out), np.full(len(intra_list), p_in)))
    total = all_probs.sum()
    if total == 0:
        raise ValueError("percentage_edge_betw leaves no sampling weight on any node pair")
    all_probs /= total

    index_selected_pairs = np.random.choice(len(all_edges), nb_edges, p=all_probs, replace=False)
    final_edges = list(all_edges[index_selected_pairs])
    if verbose:
        print(len(final_edges))

    G2 = nx.from_edgelist(final_edges)
    if len(G2) != nb_nodes:
        # reshape keeps a (0, 2) shape for an empty candidate list.
        equilibrate(G2, nb_nodes, percentage_edge_betw, 1 - percentage_edge_betw,
                    np.asarray(inter_list).reshape(-1, 2), np.asarray(intra_list).reshape(-1, 2),
                    block_assign)

    for n in list(G2.nodes()):
        G2.nodes[n]["block"] = block_assign[n]
    return G2


def equilibrate(G, nb_nodes, percentage_edge_betw, percentage_edge_within, inter_edges, intra_edges, block_assign):
    """
    Reconnect the nodes of ``range(nb_nodes)`` that are missing from ``G`` (modified in place).

    Sampling edges in the stochastic block model can leave some nodes without any edge, so
    they are absent from ``G``. Each missing node receives one new edge drawn uniformly among
    the unused candidate pairs that contain it:
    - about ``percentage_edge_betw`` of the missing nodes draw a between-block pair;
    - the others draw a within-block pair;
    - a node falls back to the other kind when it has no candidate of its assigned kind.
    To keep the edge count, exactly as many existing edges of each kind are removed. An edge
    is removed only if both endpoints still have another edge at that moment.

    Parameters
    ----------
    G : nx.Graph
        graph to repair, modified in place
    nb_nodes : int
        expected number of nodes; node labels are assumed to be ``0 .. nb_nodes - 1``
    percentage_edge_betw : float
        fraction of missing nodes reconnected with a between-block edge
    percentage_edge_within : float
        kept for backward compatibility; the within-block share is the complement of
        ``percentage_edge_betw``
    inter_edges, intra_edges : array-like
        candidate node pairs (rows of two nodes) between / within blocks
    block_assign : dict
        node -> block index

    Returns
    -------
    nx.Graph
        ``G`` itself

    Raises
    ------
    RuntimeError
        if a missing node has no unused candidate pair left
    """
    nodes_missing = sorted(set(range(nb_nodes)) - set(G.nodes()))
    if not nodes_missing:
        return G
    # reshape keeps a (0, 2) shape for an empty candidate list so row masks still work.
    inter_edges = np.asarray(inter_edges).reshape(-1, 2)
    intra_edges = np.asarray(intra_edges).reshape(-1, 2)

    nb_edge_inter = int(np.round(len(nodes_missing) * percentage_edge_betw))
    nodes_inter = set(np.random.choice(nodes_missing, nb_edge_inter, replace=False).tolist())

    register = set()

    def draw_(candidates):
        """Pick a not-yet-drawn row of ``candidates`` uniformly, register and return it (None if none left)."""
        available = [pair for pair in candidates.tolist() if frozenset(pair) not in register]
        if not available:
            return None
        pair = available[np.random.randint(len(available))]
        register.add(frozenset(pair))
        return pair

    # Draw one new edge per missing node, preferring the kind it was assigned to.
    new_edges = []
    for node in nodes_missing:
        if node in nodes_inter:
            preferred, other = inter_edges, intra_edges
        else:
            preferred, other = intra_edges, inter_edges
        pair = None
        for candidates in (preferred, other):
            pair = draw_(candidates[(candidates == node).any(axis=1)])
            if pair is not None:
                break
        if pair is None:
            raise RuntimeError("No unused candidate pair left to reconnect node {0}".format(node))
        new_edges.append(pair)

    # Remove as many existing edges of each kind as were added. Degrees are re-checked
    # before every removal so that no node becomes isolated.
    nb_new_inter = sum(1 for u, v in new_edges if block_assign[u] != block_assign[v])
    remaining = {True: nb_new_inter, False: len(new_edges) - nb_new_inter}
    existing = list(G.edges())
    random.shuffle(existing)
    for u, v in existing:
        is_inter = bool(block_assign[u] != block_assign[v])
        if remaining[is_inter] > 0 and G.degree(u) > 1 and G.degree(v) > 1:
            G.remove_edge(u, v)
            remaining[is_inter] -= 1

    for u, v in new_edges:
        G.add_edge(u, v)
    return G


def add_partitions_G(G, nb_com=5):
    """
    Assign every node of ``G`` to one of ``nb_com`` blocks drawn uniformly at random.

    The block index is stored in the ``"block"`` node attribute; ``G`` is modified in place.

    Parameters
    ----------
    G : nx.Graph
    nb_com : int
        number of blocks (labels ``0 .. nb_com - 1``)

    Returns
    -------
    nx.Graph
        ``G`` itself

    Raises
    ------
    ValueError
        if ``nb_com`` is not strictly positive
    """
    if nb_com < 1:
        raise ValueError("nb_com must be a positive integer")
    blocks = np.random.choice(np.arange(nb_com), len(G), p=[1 / nb_com] * nb_com)
    # One block per node; `blocks` has exactly len(G) entries, so no lookup can fail.
    for node, block in zip(list(G.nodes()), blocks):
        G.nodes[node]["block"] = block
    return G


def get_inter_intra_edges(G, directed=False):
    """
    Split every pair of distinct nodes of ``G`` into between-block and within-block pairs.

    Pairs follow the node order of ``G``. In undirected mode each unordered pair is listed
    once as ``[a, b]`` with ``a`` before ``b``; in directed mode both orientations are listed.
    Pairs are compared as objects, not through their string form, so labels containing "_"
    cannot collide.

    Parameters
    ----------
    G : nx.Graph
        graph whose nodes all carry a ``"block"`` attribute
    directed : bool

    Returns
    -------
    tuple of (list, list)
        ``(inter_edges, intra_edges)``, each a list of ``[n1, n2]`` pairs

    Raises
    ------
    ValueError
        if some node has no ``"block"`` attribute
    """
    block_assign = nx.get_node_attributes(G, "block")
    if len(block_assign) != len(G):
        raise ValueError("every node of G must have a 'block' attribute")

    nodes = list(G.nodes())
    pairs = itertools.permutations(nodes, 2) if directed else itertools.combinations(nodes, 2)
    inter_edges, intra_edges = [], []
    for n1, n2 in pairs:
        target = inter_edges if block_assign[n1] != block_assign[n2] else intra_edges
        target.append([n1, n2])
    return inter_edges, intra_edges

def mixed_model_spat_sbm(nb_nodes, nb_edges, nb_com, alpha, percentage_edge_betw=0.01,
                         dist_func=lambda p1, p2: np.linalg.norm(p1 - p2) ** 2):
    """
    Generate a graph mixing a stochastic block model and a spatial model.

    A spatial graph with random coordinates provides the node positions and the final number
    of edges. Nodes are then assigned to ``nb_com`` random blocks. ``round(alpha * nb_edges)``
    edges are sampled with SBM pair weights, in which between-block pairs carry a
    ``percentage_edge_betw`` share of the weight. The remaining edges are sampled with spatial
    weights ``1 / (eps + dist_func(pos_u, pos_v))``, excluding pairs already chosen.

    Parameters
    ----------
    nb_nodes : int
    nb_edges : int
    nb_com : int
        number of random blocks
    alpha : float
        fraction (between 0 and 1) of edges drawn from the SBM weights
    percentage_edge_betw : float
        share of the SBM weight given to between-block pairs
    dist_func : callable
        distance between two positions

    Returns
    -------
    nx.Graph
        graph whose nodes carry ``"block"`` and ``"pos"`` attributes
    """
    G = spatial_graph(nb_nodes, nb_edges, coords="random", dist_func=dist_func)
    G = add_partitions_G(G, nb_com)
    nb_edges = G.size()

    def nb_of_pair(N):
        return (N * (N - 1)) / 2

    block_assign = nx.get_node_attributes(G, "block")
    # Count within-block pairs over every block actually present, including the highest label.
    u_in = sum(nb_of_pair(size) for size in Counter(block_assign.values()).values())
    u_out = nb_of_pair(len(G)) - u_in
    l_out = nb_edges * percentage_edge_betw
    l_in = nb_edges - l_out
    # Random blocks may leave no between-block (or no within-block) pair.
    p_out = l_out / u_out if u_out else 0.0
    p_in = l_in / u_in if u_in else 0.0

    inter_list, intra_list = get_inter_intra_edges(G, G.is_directed())
    # One array built from the joined lists keeps the node dtype even if one list is empty.
    all_edges = np.asarray(inter_list + intra_list)
    all_probs_sbm = np.concatenate((np.full(len(inter_list), p_out), np.full(len(intra_list), p_in)))
    all_probs_sbm /= all_probs_sbm.sum()

    pos = nx.get_node_attributes(G, "pos")
    all_probs_spa = np.asarray([1 / (float_epsilon + dist_func(pos[u], pos[v])) for u, v in all_edges])
    all_probs_spa /= all_probs_spa.sum()

    # Derive the spatial share from the SBM one so both parts add up to exactly nb_edges.
    nb_edges_sbm = round(alpha * nb_edges)
    nb_edges_spa = nb_edges - nb_edges_sbm

    index_sbm = np.random.choice(len(all_edges), nb_edges_sbm, p=all_probs_sbm, replace=False)
    index_spa = np.empty(0, dtype=int)
    if nb_edges_spa > 0:
        # Zero the pairs already chosen so the spatial draw cannot pick them again.
        all_probs_spa[index_sbm] = 0.0
        all_probs_spa /= all_probs_spa.sum()
        index_spa = np.random.choice(len(all_edges), nb_edges_spa, p=all_probs_spa, replace=False)

    final_edges = list(all_edges[index_sbm]) + list(all_edges[index_spa])
    G2 = nx.from_edgelist(final_edges)

    for n in list(G2.nodes()):
        G2.nodes[n]["block"] = block_assign[n]
        G2.nodes[n]["pos"] = G.nodes[n]["pos"]

    return G2



def get_sbm_probs(G, percentage_edge_betw, verbose=False):
    """
    Return every node pair of ``G`` with its stochastic-block-model weight.

    Between-block pairs get ``p_out`` and within-block pairs get ``p_in``. The weights are
    chosen so that, for the current number of edges of ``G``, between-block pairs carry a
    ``percentage_edge_betw`` share of the total weight. Block labels are read from the
    ``"block"`` node attribute and may be any hashable values.

    Parameters
    ----------
    G : nx.Graph
        graph whose nodes carry a ``"block"`` attribute
    percentage_edge_betw : float
        share (between 0 and 1) of the weight given to between-block pairs
    verbose : bool
        print ``p_in`` and ``p_out``

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        ``(all_edges, all_probs)``: node pairs (between-block pairs first) and their
        unnormalised weights
    """
    def nb_of_pair(N):
        return (N * (N - 1)) / 2

    block_assign = nx.get_node_attributes(G, "block")
    nb_edges = G.size()

    # Count within-block pairs per label actually present (labels need not be 0..k-1).
    u_in = sum(nb_of_pair(size) for size in Counter(block_assign.values()).values())
    u_out = nb_of_pair(len(G)) - u_in
    l_out = nb_edges * percentage_edge_betw
    l_in = nb_edges - l_out
    p_out = l_out / u_out if u_out else 0.0
    p_in = l_in / u_in if u_in else 0.0
    if verbose:
        print("p_in", p_in)
        print("p_out", p_out)

    inter_list, intra_list = get_inter_intra_edges(G, G.is_directed())
    # One array built from the joined lists keeps the node dtype even if one list is empty.
    all_edges = np.asarray(inter_list + intra_list)
    all_probs = np.concatenate((np.full(len(inter_list), p_out), np.full(len(intra_list), p_in)))
    return all_edges, all_probs


def get_spat_probs(G, dist=lambda a, b: np.linalg.norm(a - b) ** 2):
    """
    Return every unordered pair of distinct nodes of ``G`` with its spatial weight.

    The weight of a pair is ``1 / (float_epsilon + dist(pos_u, pos_v))``, with positions read
    from the ``"pos"`` node attribute. With the default squared Euclidean distance, closer nodes
    get larger weights. Pairs are enumerated by object identity, so node labels containing "_"
    cannot collide.

    Parameters
    ----------
    G : nx.Graph
        graph whose nodes carry a ``"pos"`` attribute
    dist : callable
        distance between two positions

    Returns
    -------
    tuple of (list, list)
        ``(edges, probs)``: ``[u, v]`` pairs (``u`` before ``v`` in node order) and their
        unnormalised weights
    """
    pos = nx.get_node_attributes(G, "pos")
    edges, probs = [], []
    for u, v in itertools.combinations(list(G.nodes()), 2):
        edges.append([u, v])
        probs.append(1 / (float_epsilon + dist(pos[u], pos[v])))
    return edges, probs