From 151b1b0f19ad7b9b5801aa0bb7bb7106854619fc Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Wed, 17 Feb 2021 16:41:51 +0100
Subject: [PATCH] Debug visu func+ Move draw_visu to lib/visualisation + add
graph atlas generator script + delete draw graph script
---
draw_graph_script.py | 42 --------------------
generate_graph_atlas.py | 57 ++++++++++++++++++++++++++++
draw_visu.py => lib/visualisation.py | 39 ++++++++++++-------
3 files changed, 82 insertions(+), 56 deletions(-)
delete mode 100644 draw_graph_script.py
create mode 100644 generate_graph_atlas.py
rename draw_visu.py => lib/visualisation.py (82%)
diff --git a/draw_graph_script.py b/draw_graph_script.py
deleted file mode 100644
index 19a66f4..0000000
--- a/draw_graph_script.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# coding = utf-8
-
-import argparse
-import networkx as nx
-import pandas as pd
-import joblib
-import json
-import geopandas as gpd
-
-from lib.draw import draw
-
-parser = argparse.ArgumentParser()
-parser.add_argument("input_file",help="edgelist format (sep = \",\" )")
-parser.add_argument("output_file")
-parser.add_argument("--country",help="if country node",action="store_true")
-parser.add_argument("-w",action="store_true")
-
-args = parser.parse_args()
-
-G = nx.read_gexf(args.input_file)
-
-encoder = None
-labels_dict = {}
-positions = {}
-
-if args.country:
- iso2_name = json.load(open("data/ISO3166-1.alpha2.json.txt"))
- world = gpd.read_file("data/TM_WORLD_BORDERS-0/TM_WORLD_BORDERS-0.3.shp")
- world["centroid_c"] = world.centroid
- iso2_togeom = dict(world["ISO2 centroid_c".split()].values)
- positions = {k: [v.x, v.y] for k, v in iso2_togeom.items() if k in G}
-
-for node in list(G.nodes()):
- if args.country:
- labels_dict[node] = iso2_name[node]
- else:
- labels_dict[node] = node
-
-fig, ax = draw(G,labels_dict,positions)
-if args.country:
- world.boundary.plot(ax=ax)
-fig.savefig(args.output_file)
\ No newline at end of file
diff --git a/generate_graph_atlas.py b/generate_graph_atlas.py
new file mode 100644
index 0000000..1c90fe9
--- /dev/null
+++ b/generate_graph_atlas.py
@@ -0,0 +1,57 @@
+# coding = utf-8
+
+import networkx as nx
+import matplotlib.pyplot as plt
+import seaborn as sns
+import glob
+import numpy as np
+import re
+import pandas as pd
+import argparse
+import os
+
+parser = argparse.ArgumentParser()  # CLI: three required positional arguments
+parser.add_argument("graph_directory")  # directory scanned below for *.gml files
+parser.add_argument("output_directory")  # directory the PDF pages are saved into
+parser.add_argument("number_of_files",type=int)  # number of chunks (one PDF each) the graph list is split into
+
+args = parser.parse_args()
+graph_dir = args.graph_directory
+if not os.path.exists(graph_dir):
+    raise FileNotFoundError("{0} does not exist".format(graph_dir))
+# Sorted so the file order does not depend on the order glob happens to return.
+fns = sorted(glob.glob(os.path.join(graph_dir,"*.gml")))
+
+def draw_graph(fn, **kwargs):
+    """Draw the graph stored at ``fn.values[0]`` with a spring layout; ``kwargs`` is ignored."""
+    path = fn.values[0]
+    graph = nx.read_gml(path)
+    draw_kwargs = {"pos": nx.spring_layout(graph), "node_size": 20}
+    if "stochastic_block_model_graph" in path:  # colour these graphs by each node's "block" attribute
+        draw_kwargs["node_color"] = list(nx.get_node_attributes(graph, "block").values())
+        draw_kwargs["cmap"] = plt.cm.Dark2
+    nx.draw(graph, **draw_kwargs)
+
+
+def parameter(G):
+    """Return a multi-line summary of ``G``: vertex and edge counts, then one line per graph attribute."""
+    parts = [" \nVertices ({0}) Edges ({1}) ".format(len(G), G.size())]
+    for key, val in G.graph.items():
+        if key not in ("nb_nodes", "nb_edges"):  # these two attributes are deliberately left out
+            parts.append("\n{0} ({1})".format(key.replace("_", " "), val))
+    return "".join(parts)
+
+
+# One row per graph file; "label" (cleaned file name + graph parameters) becomes the facet title.
+df_fns = pd.DataFrame(fns, columns=["filename"])
+df_fns["index"] = np.arange(len(fns))
+df_fns["label"] = df_fns.filename.apply(lambda x: os.path.basename(x)[6:].replace(
+    "stochastic_block_model_graph", "SBM").replace("_", " ").replace(".gml", ""))
+df_fns["label"] = df_fns.apply(lambda x: x.label + parameter(nx.read_gml(x.filename)), axis=1)
+for ix, split in enumerate(np.array_split(df_fns, args.number_of_files)):
+    g = sns.FacetGrid(split, col="label", col_wrap=4, height=5)
+    g.map(draw_graph, "filename")
+    for ax in g.axes.flatten():
+        # Drop the literal "label = " prefix (str.strip would also eat label letters) and the number before the first break.
+        ax.set_title(re.sub(r"\d+ \n", "\n", re.sub(r"^label = ", "", ax.get_title())).strip())
+    g.savefig(os.path.join(args.output_directory, "graph_atlas_fb{0}.pdf".format(ix)), bbox_inches="tight")
diff --git a/draw_visu.py b/lib/visualisation.py
similarity index 82%
rename from draw_visu.py
rename to lib/visualisation.py
index 69d4738..aad6b85 100644
--- a/draw_visu.py
+++ b/lib/visualisation.py
@@ -42,23 +42,34 @@ def load_data(fn, graph_dir):
return df
-def set_custom_palette(x, y, max_color='red', close_color='turquoise', other_color='lightgrey'):
- def get_color(x, max_val, min_diff):
- if x == max_val:
- return max_color
- elif x > max_val - (0.01 + min_diff) and x < max_val + (0.01 + min_diff):
- return close_color
- else:
- return other_color
-
+def set_custom_palette(x, y, max_color='red', other_color='lightgrey'):
pal = []
df = pd.concat((x, y), axis=1)
- mean_df = df.groupby(x.name, as_index=False).mean()
- mean_per_x = dict(mean_df.values)
- max_val = mean_df[y.name].max()
- min_diff = (max_val - mean_df[y.name]).median()
- col_per_method = {k: get_color(v, max_val, min_diff) for k, v in mean_per_x.items()}
+ min_df = df.groupby(x.name, as_index=False).mean()
+ min_df[y.name] = min_df[y.name] - (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+
+ max_df = df.groupby(x.name, as_index=False).mean()
+ max_df[y.name] = max_df[y.name] + (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+ max_min_row = min_df[y.name].argmax()
+ max_min_key, max_min_value = min_df.iloc[max_min_row][x.name], min_df.iloc[max_min_row][y.name]
+ col_per_method = {}
+ if max_min_value > max_df[~(max_df[x.name] == max_min_key)][y.name].max():
+
+ for k in x:
+ if k == max_min_key:
+ col_per_method[k] = max_color
+ else:
+ col_per_method[k] = other_color
+
+ else:
+
+ max_keys = max_df[max_df[y.name] > max_min_value][x.name].values.tolist()
+ for k in x:
+ if k in max_keys:
+ col_per_method[k] = max_color
+ else:
+ col_per_method[k] = other_color
for i, val in enumerate(x):
pal.append(col_per_method[val])
--
GitLab