From 151b1b0f19ad7b9b5801aa0bb7bb7106854619fc Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Wed, 17 Feb 2021 16:41:51 +0100
Subject: [PATCH] Debug visu func+ Move draw_visu to lib/visualisation + add graph atlas generator script + delete draw graph script

---
 draw_graph_script.py                 | 42 --------------------
 generate_graph_atlas.py              | 75 ++++++++++++++++++++++++++++
 draw_visu.py => lib/visualisation.py | 39 ++++++++++++-------
 3 files changed, 100 insertions(+), 56 deletions(-)
 delete mode 100644 draw_graph_script.py
 create mode 100644 generate_graph_atlas.py
 rename draw_visu.py => lib/visualisation.py (82%)

diff --git a/draw_graph_script.py b/draw_graph_script.py
deleted file mode 100644
index 19a66f4..0000000
--- a/draw_graph_script.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# coding = utf-8
-
-import argparse
-import networkx as nx
-import pandas as pd
-import joblib
-import json
-import geopandas as gpd
-
-from lib.draw import draw
-
-parser = argparse.ArgumentParser()
-parser.add_argument("input_file",help="edgelist format (sep = \",\" )")
-parser.add_argument("output_file")
-parser.add_argument("--country",help="if country node",action="store_true")
-parser.add_argument("-w",action="store_true")
-
-args = parser.parse_args()
-
-G = nx.read_gexf(args.input_file)
-
-encoder = None
-labels_dict = {}
-positions = {}
-
-if args.country:
-    iso2_name = json.load(open("data/ISO3166-1.alpha2.json.txt"))
-    world = gpd.read_file("data/TM_WORLD_BORDERS-0/TM_WORLD_BORDERS-0.3.shp")
-    world["centroid_c"] = world.centroid
-    iso2_togeom = dict(world["ISO2 centroid_c".split()].values)
-    positions = {k: [v.x, v.y] for k, v in iso2_togeom.items() if k in G}
-
-for node in list(G.nodes()):
-    if args.country:
-        labels_dict[node] = iso2_name[node]
-    else:
-        labels_dict[node] = node
-
-fig, ax = draw(G,labels_dict,positions)
-if args.country:
-    world.boundary.plot(ax=ax)
-fig.savefig(args.output_file)
\ No newline at end of file
diff --git a/generate_graph_atlas.py b/generate_graph_atlas.py
new file mode 100644
index 0000000..1c90fe9
--- /dev/null
+++ b/generate_graph_atlas.py
@@ -0,0 +1,75 @@
+# coding = utf-8
+"""Draw every GML graph found in a directory and save the drawings as PDF "atlas" pages."""
+
+import networkx as nx
+import matplotlib.pyplot as plt
+import seaborn as sns
+import glob
+import numpy as np
+import re
+import pandas as pd
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+parser.add_argument("graph_directory")
+parser.add_argument("output_directory")
+parser.add_argument("number_of_files",type=int)
+
+args = parser.parse_args()
+graph_dir = args.graph_directory
+if not os.path.exists(graph_dir):
+    raise FileNotFoundError("{0} does not exist".format(graph_dir))
+
+fns = sorted(glob.glob(os.path.join(graph_dir,"*.gml")))
+
+def draw_graph(fn, **kwargs):
+    """Draw the graph stored in the GML file named by ``fn.values[0]``.
+
+    Called through ``FacetGrid.map``, which supplies the facet's "filename" column as ``fn``;
+    any keyword argument it adds is accepted and ignored. Nodes of files whose name contains
+    "stochastic_block_model_graph" are coloured by their "block" attribute.
+    """
+    path = fn.values[0]
+    G = nx.read_gml(path)
+    pos = nx.spring_layout(G)
+
+    if "stochastic_block_model_graph" in path:
+        nx.draw(G, pos=pos, node_color=list(nx.get_node_attributes(G, "block").values()), node_size=20,
+                cmap=plt.cm.Dark2)
+    else:
+        nx.draw(G, pos=pos, node_size=20)
+
+
+def parameter(G):
+    """Return a caption with the vertex and edge counts followed by the graph attributes."""
+    str_ = " \nVertices ({0}) Edges ({1}) ".format(len(G), G.size())
+    for key, val in G.graph.items():
+        # Skipped -- presumably redundant with the counts on the first caption line.
+        if key == "nb_nodes" or key == "nb_edges":
+            continue
+        str_ += "\n{0} ({1})".format(key.replace("_", " "), val)
+    return str_
+
+
+df_fns = pd.DataFrame(fns, columns=["filename"])
+df_fns["index"] = np.arange(len(fns))
+# basename() rather than split("/") so the 6-character file name prefix is dropped on any OS.
+df_fns["label"] = df_fns.filename.apply(lambda x: os.path.basename(x)[6:])
+df_fns["label"] = df_fns.label.apply(
+    lambda x: x.replace("stochastic_block_model_graph", "SBM").replace("_", " ").replace(".gml", ""))
+df_fns["label"] = df_fns.apply(lambda x: x.label + parameter(nx.read_gml(x.filename)), axis=1)
+for ix, split in enumerate(np.array_split(df_fns, args.number_of_files)):
+    if split.empty:
+        # More pages requested than graphs available: nothing to draw for this chunk.
+        continue
+    g = sns.FacetGrid(split, col="label", col_wrap=4, height=5)
+    g.map(draw_graph, "filename")
+    axes = g.axes.flatten()
+    for ax in axes:
+        # Drop the "label = " prefix with a regex: str.strip("label = ") strips a *set of
+        # characters* and therefore also removed leading letters of the label itself.
+        title = re.sub(r"^label = ", "", ax.get_title())
+        ax.set_title(re.sub(r"\d+ \n", "\n", title).strip())
+    # savefig() has no "bbox" argument; "bbox_inches" is the one that crops the margins.
+    g.savefig(os.path.join(args.output_directory,"graph_atlas_fb{0}.pdf".format(ix)), bbox_inches="tight")
diff --git a/draw_visu.py b/lib/visualisation.py
similarity index 82%
rename from draw_visu.py
rename to lib/visualisation.py
index 69d4738..aad6b85 100644
--- a/draw_visu.py
+++ b/lib/visualisation.py
@@ -42,23 +42,34 @@ def load_data(fn, graph_dir):
     return df
 
 
-def set_custom_palette(x, y, max_color='red', close_color='turquoise', other_color='lightgrey'):
-    def get_color(x, max_val, min_diff):
-        if x == max_val:
-            return max_color
-        elif x > max_val - (0.01 + min_diff) and x < max_val + (0.01 + min_diff):
-            return close_color
-        else:
-            return other_color
-
+def set_custom_palette(x, y, max_color='red', other_color='lightgrey'):
     pal = []
     df = pd.concat((x, y), axis=1)
 
-    mean_df = df.groupby(x.name, as_index=False).mean()
-    mean_per_x = dict(mean_df.values)
-    max_val = mean_df[y.name].max()
-    min_diff = (max_val - mean_df[y.name]).median()
-    col_per_method = {k: get_color(v, max_val, min_diff) for k, v in mean_per_x.items()}
+    min_df = df.groupby(x.name, as_index=False).mean()
+    min_df[y.name] = min_df[y.name] - (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+
+    max_df = df.groupby(x.name, as_index=False).mean()
+    max_df[y.name] = max_df[y.name] + (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+    max_min_row = min_df[y.name].argmax()
+    max_min_key, max_min_value = min_df.iloc[max_min_row][x.name], min_df.iloc[max_min_row][y.name]
+    col_per_method = {}
+    if max_min_value > max_df[~(max_df[x.name] == max_min_key)][y.name].max():
+
+        for k in x:
+            if k == max_min_key:
+                col_per_method[k] = max_color
+            else:
+                col_per_method[k] = other_color
+
+    else:
+
+        max_keys = max_df[max_df[y.name] > max_min_value][x.name].values.tolist()
+        for k in x:
+            if k in max_keys:
+                col_per_method[k] = max_color
+            else:
+                col_per_method[k] = other_color
     for i, val in enumerate(x):
         pal.append(col_per_method[val])
 
-- 
GitLab