Skip to content
Snippets Groups Projects
draw.py 2.91 KiB
Newer Older
Fize Jacques's avatar
Fize Jacques committed
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import networkx as nx
import pandas as pd
import numpy as np

from fa2 import ForceAtlas2


def get_force_atlas(weight_influence=0, scaling_ratio=3.0, gravity=5):
    """
    Build a ForceAtlas2 layout engine with this module's fixed settings.

    Parameters
    ----------
    weight_influence : float
        Forwarded as ForceAtlas2's ``edgeWeightInfluence``.
    scaling_ratio : float
        Forwarded as ForceAtlas2's ``scalingRatio``.
    gravity : float
        Forwarded as ForceAtlas2's ``gravity``.

    Returns
    -------
    ForceAtlas2
        The configured layout object; no layout has been computed yet.
    """
    behaviour = {
        "outboundAttractionDistribution": True,  # Dissuade hubs
        "linLogMode": False,  # NOT IMPLEMENTED
        "adjustSizes": False,  # Prevent overlap (NOT IMPLEMENTED)
        "edgeWeightInfluence": weight_influence,
    }
    performance = {
        "jitterTolerance": 1.0,  # Tolerance
        "barnesHutOptimize": True,
        "barnesHutTheta": 1.2,
        "multiThreaded": False,  # NOT IMPLEMENTED
    }
    tuning = {
        "scalingRatio": scaling_ratio,
        "strongGravityMode": False,
        "gravity": gravity,
    }
    # Logging stays disabled; every other option comes from the groups above.
    return ForceAtlas2(**behaviour, **performance, **tuning, verbose=False)


def draw(G, labels_dict=None, iteration_force_atlase=2000, figsize=(40, 20), font_size=12, stroke_width=3,
         stroke_color="black", font_color="white", edge_cmap=plt.cm.viridis, weight=True,
         edge_width_scale=200):
    """
    Draw ``G`` using a ForceAtlas2 layout and return the axes it was drawn on.

    Parameters
    ----------
    G : networkx graph
        Graph to draw. When ``weight`` is True, every edge must carry a
        ``'weight'`` attribute (a ``KeyError`` is raised otherwise).
    labels_dict : dict, optional
        Mapping node -> label text. Nodes absent from the mapping, or all
        nodes when the mapping is empty/None, are labelled with the node itself.
    iteration_force_atlase : int
        Number of ForceAtlas2 iterations used to compute node positions.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    font_size : int
        Font size of the node labels.
    stroke_width : float
        Width of the outline stroked around each label.
    stroke_color : str
        Colour of the label outline.
    font_color : str
        Colour of the label text.
    edge_cmap : matplotlib colormap
        Colormap that colours edges by weight (used only when ``weight`` is True).
    weight : bool
        If True, edge width and colour follow the ``'weight'`` attribute and a
        colorbar is added. If False, edges are drawn in plain grey.
    edge_width_scale : float
        Multiplier turning an edge weight into a line width (default 200,
        the value previously hard-coded).

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the drawing.
    """
    # Compute node positions with the Force Atlas algorithm.
    force_atlas = get_force_atlas()
    positions = force_atlas.forceatlas2_networkx_layout(G,
                                                        pos=None,
                                                        iterations=iteration_force_atlase)

    # plt.subplots always creates a new figure, so no prior cleanup is needed
    # (the former plt.gcf() call only risked creating a stray empty figure).
    fig, ax = plt.subplots(1, figsize=figsize)

    # Draw nodes
    nx.draw_networkx_nodes(G, positions, node_color='#999', ax=ax)

    # Draw edges
    if weight:
        # Single pass over edges. G.edges(data=True) iterates in the same order
        # as G.edges(), which draw_networkx_edges uses when no edgelist is given.
        # It also works for multigraphs, unlike G[u][v]['weight'].
        weights = [data['weight'] for _, _, data in G.edges(data=True)]
        if weights:
            norm = plt.Normalize(vmin=min(weights), vmax=max(weights))
            nx.draw_networkx_edges(G, positions, edge_color=weights,
                                   width=[w * edge_width_scale for w in weights],
                                   edge_cmap=edge_cmap, edge_vmin=norm.vmin, edge_vmax=norm.vmax,
                                   ax=ax)
            # Build the colorbar from an explicit mappable instead of the
            # draw_networkx_edges return value. For directed graphs that value
            # is a list of patches, which colorbar cannot use.
            mappable = plt.cm.ScalarMappable(norm=norm, cmap=edge_cmap)
            mappable.set_array([])
            fig.colorbar(mappable, ax=ax)
    else:
        nx.draw_networkx_edges(G, positions, ax=ax, edge_color="#999")

    # Plot node labels
    for node, pos in positions.items():
        label = labels_dict.get(node, node) if labels_dict else node
        text = ax.text(pos[0], pos[1], label, color=font_color,
                       ha='center', va='center', size=font_size)
        # Outline the text so it stays readable over nodes and edges.
        text.set_path_effects([path_effects.Stroke(linewidth=stroke_width, foreground=stroke_color),
                               path_effects.Normal()])

    ax.axis("off")
    return ax