From 151b1b0f19ad7b9b5801aa0bb7bb7106854619fc Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Wed, 17 Feb 2021 16:41:51 +0100
Subject: [PATCH] Debug visu func + Move draw_visu to lib/visualisation + add
 graph atlas generator script + delete draw graph script

---
 draw_graph_script.py                 | 42 --------------------
 generate_graph_atlas.py              | 57 ++++++++++++++++++++++++++++
 draw_visu.py => lib/visualisation.py | 39 ++++++++++++-------
 3 files changed, 82 insertions(+), 56 deletions(-)
 delete mode 100644 draw_graph_script.py
 create mode 100644 generate_graph_atlas.py
 rename draw_visu.py => lib/visualisation.py (82%)

diff --git a/draw_graph_script.py b/draw_graph_script.py
deleted file mode 100644
index 19a66f4..0000000
--- a/draw_graph_script.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# coding = utf-8
-
-import argparse
-import networkx as nx
-import pandas as pd
-import joblib
-import json
-import geopandas as gpd
-
-from lib.draw import draw
-
-parser = argparse.ArgumentParser()
-parser.add_argument("input_file",help="edgelist format (sep = \",\" )")
-parser.add_argument("output_file")
-parser.add_argument("--country",help="if country node",action="store_true")
-parser.add_argument("-w",action="store_true")
-
-args = parser.parse_args()
-
-G = nx.read_gexf(args.input_file)
-
-encoder = None
-labels_dict = {}
-positions = {}
-
-if args.country:
-    iso2_name = json.load(open("data/ISO3166-1.alpha2.json.txt"))
-    world = gpd.read_file("data/TM_WORLD_BORDERS-0/TM_WORLD_BORDERS-0.3.shp")
-    world["centroid_c"] = world.centroid
-    iso2_togeom = dict(world["ISO2 centroid_c".split()].values)
-    positions = {k: [v.x, v.y] for k, v in iso2_togeom.items() if k in G}
-
-for node in list(G.nodes()):
-    if args.country:
-        labels_dict[node] = iso2_name[node]
-    else:
-        labels_dict[node] = node
-
-fig, ax = draw(G,labels_dict,positions)
-if args.country:
-    world.boundary.plot(ax=ax)
-fig.savefig(args.output_file)
\ No newline at end of file
diff --git a/generate_graph_atlas.py b/generate_graph_atlas.py
new file mode 100644
index 0000000..1c90fe9
--- /dev/null
+++ b/generate_graph_atlas.py
@@ -0,0 +1,57 @@
+# coding = utf-8
+
+import networkx as nx
+import matplotlib.pyplot as plt
+import seaborn as sns
+import glob
+import numpy as np
+import re
+import pandas as pd
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+parser.add_argument("graph_directory")  # directory searched for *.gml files
+parser.add_argument("output_directory")  # directory the atlas PDFs are written to
+parser.add_argument("number_of_files",type=int)  # number of chunks the file list is split into (one PDF each)
+
+args = parser.parse_args()
+graph_dir = args.graph_directory
+if not os.path.exists(graph_dir):
+    raise FileNotFoundError("{0} does not exist".format(graph_dir))
+
+fns = sorted(glob.glob(os.path.join(graph_dir,"*.gml")))  # sorted list of the *.gml paths found in graph_dir
+
+def draw_graph(fn, **kwargs):
+    """Draw the GML graph at ``fn.values[0]`` with a spring layout; SBM graphs are coloured by ``block``."""
+    path = fn.values[0]
+    graph = nx.read_gml(path)
+    draw_options = {"pos": nx.spring_layout(graph), "node_size": 20}
+    if "stochastic_block_model_graph" in path:
+        blocks = nx.get_node_attributes(graph, "block")
+        draw_options.update(node_color=list(blocks.values()), cmap=plt.cm.Dark2)
+    nx.draw(graph, **draw_options)
+
+
+def parameter(G):
+    """Return a caption: vertex/edge counts, then one line per ``G.graph`` item except nb_nodes/nb_edges."""
+    pieces = [" \nVertices ({0}) Edges ({1}) ".format(len(G), G.size())]
+    for key, val in G.graph.items():
+        if key not in ("nb_nodes", "nb_edges"):
+            pieces.append("\n{0} ({1})".format(key.replace("_", " "), val))
+    return "".join(pieces)
+
+
+# One row per graph file; "label" (file name + graph statistics) becomes the facet title.
+df_fns = pd.DataFrame(fns, columns=["filename"])
+df_fns["index"] = np.arange(len(fns))
+df_fns["label"] = df_fns.filename.apply(lambda x: os.path.basename(x)[6:])  # basename: portable, unlike split("/")
+df_fns["label"] = df_fns.label.apply(
+    lambda x: x.replace("stochastic_block_model_graph", "SBM").replace("_", " ").replace(".gml", ""))
+df_fns["label"] = df_fns.apply(lambda x: x.label + parameter(nx.read_gml(x.filename)), axis=1)
+for ix, split in enumerate(np.array_split(df_fns, args.number_of_files)):
+    g = sns.FacetGrid(split, col="label", col_wrap=4, height=5)
+    g.map(draw_graph, "filename")
+    for ax in g.axes.flatten():  # drop the "label = " prefix as a prefix; str.strip("label = ") ate title letters
+        ax.set_title(re.sub(r"\d+ \n", "\n", re.sub(r"^label = ", "", ax.get_title())).strip())
+    g.savefig(os.path.join(args.output_directory, "graph_atlas_fb{0}.pdf".format(ix)), bbox_inches="tight")
diff --git a/draw_visu.py b/lib/visualisation.py
similarity index 82%
rename from draw_visu.py
rename to lib/visualisation.py
index 69d4738..aad6b85 100644
--- a/draw_visu.py
+++ b/lib/visualisation.py
@@ -42,23 +42,34 @@ def load_data(fn, graph_dir):
     return df
 
 
-def set_custom_palette(x, y, max_color='red', close_color='turquoise', other_color='lightgrey'):
-    def get_color(x, max_val, min_diff):
-        if x == max_val:
-            return max_color
-        elif x > max_val - (0.01 + min_diff) and x < max_val + (0.01 + min_diff):
-            return close_color
-        else:
-            return other_color
-
+def set_custom_palette(x, y, max_color='red', other_color='lightgrey'):
     pal = []
     df = pd.concat((x, y), axis=1)
-    mean_df = df.groupby(x.name, as_index=False).mean()
-    mean_per_x = dict(mean_df.values)
-    max_val = mean_df[y.name].max()
-    min_diff = (max_val - mean_df[y.name]).median()
-    col_per_method = {k: get_color(v, max_val, min_diff) for k, v in mean_per_x.items()}
+    min_df = df.groupby(x.name, as_index=False).mean()
+    min_df[y.name] = min_df[y.name] - (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+
+    max_df = df.groupby(x.name, as_index=False).mean()
+    max_df[y.name] = max_df[y.name] + (df.groupby(x.name, as_index=False).sem()[y.name]) * 2
+    max_min_row = min_df[y.name].argmax()
 
+    max_min_key, max_min_value = min_df.iloc[max_min_row][x.name], min_df.iloc[max_min_row][y.name]
+    col_per_method = {}
+    if max_min_value > max_df[~(max_df[x.name] == max_min_key)][y.name].max():
+
+        for k in x:
+            if k == max_min_key:
+                col_per_method[k] = max_color
+            else:
+                col_per_method[k] = other_color
+
+    else:
+
+        max_keys = max_df[max_df[y.name] > max_min_value][x.name].values.tolist()
+        for k in x:
+            if k in max_keys:
+                col_per_method[k] = max_color
+            else:
+                col_per_method[k] = other_color
     for i, val in enumerate(x):
         pal.append(col_per_method[val])
 
-- 
GitLab