import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import seaborn as sns
import networkx as nx
import pandas as pd
import numpy as np
from glob import glob

from fa2 import ForceAtlas2


def get_force_atlas(weight_influence=0, scaling_ratio=3.0, gravity=5):
    """
    Build a ForceAtlas2 layout engine using this module's preset configuration.

    Only the three arguments below are tunable; every other ForceAtlas2 option
    is fixed to the values listed in the body.

    Parameters
    ----------
    weight_influence : float
        value forwarded as ``edgeWeightInfluence``, between 0 and 1 (default 0)
    scaling_ratio : float or int
        value forwarded as ``scalingRatio``, see fa2 documentation (default 3.0)
    gravity : float or int
        value forwarded as ``gravity``, see fa2 documentation (default 5)

    Returns
    -------
    ForceAtlas2
        a newly created, configured instance
    """
    settings = {
        # Behavior alternatives
        "outboundAttractionDistribution": True,  # Dissuade hubs
        "linLogMode": False,  # NOT IMPLEMENTED
        "adjustSizes": False,  # Prevent overlap (NOT IMPLEMENTED)
        "edgeWeightInfluence": weight_influence,
        # Performance
        "jitterTolerance": 1.0,  # Tolerance
        "barnesHutOptimize": True,
        "barnesHutTheta": 1.2,
        "multiThreaded": False,  # NOT IMPLEMENTED
        # Tuning
        "scalingRatio": scaling_ratio,
        "strongGravityMode": False,
        "gravity": gravity,
        # Log
        "verbose": False,
    }
    return ForceAtlas2(**settings)


def draw(G, labels_dict=None, iteration_force_atlase=2000, figsize=(40, 20), font_size=12, stroke_width=3,
         stroke_color="black", font_color="white", edge_cmap=plt.cm.viridis, weight=True):
    """
    Draw a NetworkX graph whose node positions are computed with Force Atlas 2.

    Parameters
    ----------
    G : nx.Graph
        graph instance
    labels_dict: dict or None
        label for each node id; when None or empty, the node id itself is used as label
    iteration_force_atlase: int
        nb of iteration for the Force Atlas algorithm
    figsize: tuple
        figure size (matplotlib)
    font_size: int
        font size
    stroke_width : int
        text contour size
    stroke_color: str
        text contour color
    font_color : str
        text color
    edge_cmap: matplotlib.pyplot.cm
        Matplotlib Colormap instance used when edges are associated with a weight
    weight : bool
        if True, the ``weight`` attribute of each edge sets its color and its width
        (weight * 200) and a colorbar is added; if False, edges are drawn in plain grey

    Returns
    -------
    Figure, AxesSubplot
        the matplotlib figure and the axes the graph is drawn on

    Raises
    ------
    KeyError
        if ``weight`` is True and an edge has no ``weight`` attribute, or if a non-empty
        ``labels_dict`` has no entry for one of the nodes
    """
    # Compute node position using the Force Atlas algorithm
    force_atlas = get_force_atlas()
    positions = force_atlas.forceatlas2_networkx_layout(G,
                                                        pos=None,
                                                        iterations=iteration_force_atlase)
    # Initialise the figure canvas. No plt.gcf() beforehand: it does not clear anything
    # and would create an extra empty figure when none is currently open.
    fig, ax = plt.subplots(1, figsize=figsize)

    # Draw nodes
    nx.draw_networkx_nodes(G, positions, node_color='#999', ax=ax)

    # Draw edges
    edge_weights = []
    if weight:
        # Same iteration order as G.edges(), which is the order the drawing function uses
        edge_weights = [data['weight'] for _, _, data in G.edges(data=True)]
        nx.draw_networkx_edges(G, positions, edge_color=edge_weights,
                               width=[w * 200 for w in edge_weights],
                               edge_cmap=edge_cmap, ax=ax)
    else:
        nx.draw_networkx_edges(G, positions, ax=ax, edge_color="#999")

    # Plot nodes label
    for node, pos in positions.items():
        label = labels_dict[node] if labels_dict else node
        text = ax.text(pos[0], pos[1], label, color=font_color,
                       ha='center', va='center', size=font_size)
        # Outline the text so that it stays readable on top of nodes and edges
        text.set_path_effects([path_effects.Stroke(linewidth=stroke_width, foreground=stroke_color),
                               path_effects.Normal()])

    # Plot colorbar (skipped for a graph without edges: there is no value range to show)
    if weight and edge_weights:
        sm = plt.cm.ScalarMappable(cmap=edge_cmap,
                                   norm=plt.Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
        sm.set_array([])
        # The mappable is not attached to any Axes, so the target Axes is passed explicitly
        fig.colorbar(sm, ax=ax)

    # Work on the objects created above rather than on pyplot's implicit current figure/axes
    ax.axis("off")
    fig.tight_layout()
    return fig, ax


def average_degree(graph_dir, ext=".txt"):
    """
    Produce a figure that shows the average degree per number of edges in a graph dataset.

    Each file matching ``<graph_dir>/*<ext>`` is read with ``pd.read_csv`` as a
    header-less two-column (source, target) edge list and loaded as a directed graph.

    Parameters
    ----------
    graph_dir: str or path-like
        graph dataset directory path
    ext : str
        extension of the graph file (must be edgelist format)

    Returns
    -------
    Figure, AxesSubplot
        the matplotlib figure and the axes holding the scatter plot
    """
    data = []
    # Formatting (rather than ``+``) lets graph_dir be a path-like object as well as a str
    for fn in glob(f"{graph_dir}/*{ext}"):
        edge_df = pd.read_csv(fn, header=None, names="source target".split())
        G = nx.from_pandas_edgelist(edge_df, create_using=nx.DiGraph())
        # Read the degrees straight from the (node, degree) pairs: packing the pairs into one
        # numpy array turns the degrees into strings as soon as node ids are not numeric.
        degrees = [degree for _, degree in G.degree()]
        avg_degree = sum(degrees) / len(degrees) if degrees else float("nan")
        data.append([G.number_of_edges(), avg_degree])
    df = pd.DataFrame(data, columns="nb_edges avg_degree".split())
    # No plt.gcf() beforehand: it would create an extra empty figure when none is open
    fig, ax = plt.subplots(1, figsize=(10, 5))
    ax = sns.scatterplot(data=df, x="nb_edges", y="avg_degree", hue="nb_edges", legend=False, ax=ax)
    ax.set(xlabel="Number of edges", ylabel="Average Degree")
    return fig, ax