# visualisation.py
# coding = utf-8

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import re
import os

import networkx as nx


def get_graph_attr(fn, graph_dir):
    """Return the graph-level attribute dictionary stored in a GML file.

    Parameters
    ----------
    fn : str
        File name of the GML file, relative to ``graph_dir``.
    graph_dir : str
        Directory containing the graph files.

    Returns
    -------
    dict
        The ``graph`` attribute of the graph read by ``networkx.read_gml``.

    Raises
    ------
    FileNotFoundError
        If ``graph_dir/fn`` does not exist (the message is the full path).
    """
    path = os.path.join(graph_dir, fn)
    if os.path.exists(path):
        return nx.read_gml(path).graph
    raise FileNotFoundError(path)


def get_sample_id_old(ch):
    """Older sample-id extractor (``load_data`` uses ``get_sample_id`` instead).

    Takes the first run of digits found in ``ch``. If that run is exactly three
    digits long, its last two digits are returned; otherwise only its last digit
    is returned. The result is a string, not an int.

    Raises
    ------
    IndexError
        If ``ch`` contains no digit.
    """
    # Raw string: "\d" in a normal string literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    id_graph = re.findall(r"\d+", ch)[0]
    # NOTE(review): the 3-digit special case presumably matches a legacy file
    # naming scheme -- confirm before reusing this helper.
    return id_graph[-2:] if len(id_graph) == 3 else id_graph[-1:]


def get_sample_id(fn, file_format="gml"):
    """Return the integer sample id encoded in a file name.

    The file name is expected to look like ``<prefix>_<id>.<file_format>``,
    e.g. ``"graph_12.gml"`` -> ``12``. The extension is optional.

    Parameters
    ----------
    fn : str
        File name to parse.
    file_format : str, optional
        Extension (without the dot) removed before parsing. Defaults to ``"gml"``.

    Returns
    -------
    int
        The last ``_``-separated token, converted to ``int``.

    Raises
    ------
    ValueError
        If that token is not an integer.
    """
    suffix = ".{0}".format(file_format)
    # Remove the extension as a real suffix. str.strip(suffix) would instead trim
    # any of its *characters* from both ends, which eats id digits when the
    # extension contains digits (e.g. "sample_15.h5" -> 1).
    if fn.endswith(suffix):
        fn = fn[:-len(suffix)]
    return int(fn.split("_")[-1])


def load_data(fn, graph_dir):
    """Load a tab-separated results file and derive the columns used for plotting.

    Parameters
    ----------
    fn : str
        Path of the TSV results file. It must contain at least the ``filename``
        and ``name`` columns.
    graph_dir : str
        Directory containing the GML files referenced by the ``filename`` column.

    Returns
    -------
    pandas.DataFrame
        The loaded table with four added columns:

        - ``type_graph``: the file name with its first 6 characters dropped and
          its ``_<digits>.gml`` part removed, with underscores turned into spaces.
        - ``parameters``: graph-level attribute dict read from the GML file.
        - ``sample``: integer sample id parsed from the file name.
        - ``type_method``: ``"heuristic"`` if ``name`` is one of the listed
          non-embedding methods, otherwise ``"network_embedding_based"``.

    Raises
    ------
    FileNotFoundError
        If a referenced graph file is missing from ``graph_dir``.
    """
    df = pd.read_csv(fn, sep="\t")
    # NOTE(review): the first 6 characters of each file name are discarded,
    # presumably a fixed prefix such as "graph_" -- confirm against the producer.
    df["type_graph"] = df["filename"].apply(
        lambda x: re.sub(r"_\d+\.gml", "", x[6:]).replace("_", " ")
    )
    df["parameters"] = df["filename"].apply(lambda x: get_graph_attr(x, graph_dir))
    df["sample"] = df["filename"].apply(get_sample_id)
    non_ne = {'random_prediction', 'common_neighbours', 'jaccard_coefficient', 'adamic_adar_index',
              'preferential_attachment', 'resource_allocation_index', 'stochastic_block_model',
              'stochastic_block_model_degree_corrected', 'spatial_link_prediction'}
    # Bracket access: ``df.name`` is fragile attribute access to a column.
    df["type_method"] = df["name"].apply(
        lambda x: "heuristic" if x in non_ne else "network_embedding_based"
    )
    return df


def set_custom_palette(x, y, max_color='red', other_color='lightgrey'):
    """Build a per-observation colour list that highlights the best category.

    For every category of ``x`` an interval ``mean(y) +/- 2 * sem(y)`` is
    computed. The "best" category is the one with the highest lower bound.
    If that lower bound is strictly greater than every other category's upper
    bound, only the best category is coloured ``max_color``. Otherwise every
    category whose upper bound reaches the best lower bound (including the best
    one itself) is coloured ``max_color``. All other categories get ``other_color``.

    Parameters
    ----------
    x : pandas.Series
        Category label of each observation. Must be named, with a name
        different from ``y.name``.
    y : pandas.Series
        Numeric value of each observation. Must be named.
    max_color : str, optional
        Colour of the highlighted category/categories.
    other_color : str, optional
        Colour of every other category.

    Returns
    -------
    list
        One colour per element of ``x``, in the same order as ``x``.
    """
    df = pd.concat((x, y), axis=1)
    grouped = df.groupby(x.name)[y.name]
    mean = grouped.mean()
    margin = grouped.sem() * 2
    lower = mean - margin
    upper = mean + margin

    best_key = lower.idxmax()
    best_lower = lower[best_key]

    # With a single category the right-hand side is NaN, so the comparison is
    # False and we fall through to the overlap rule.
    if best_lower > upper.drop(best_key).max():
        highlighted = {best_key}
    else:
        # ">=" keeps the best category itself, even when its SEM is 0. It is
        # also the exact complement of the strict test above.
        highlighted = set(upper[upper >= best_lower].index)

    return [max_color if key in highlighted else other_color for key in x]

def highlight_barplot(x, y, **kwargs):
    """Draw a bar plot of ``y`` per category of ``x`` with the best category highlighted.

    Colours come from ``set_custom_palette`` and replace any ``palette`` passed
    in ``kwargs``. All other keyword arguments are forwarded to ``sns.barplot``.
    The signature matches what ``FacetGrid.map`` passes: two data Series plus
    keyword arguments.
    """
    # set_custom_palette returns one colour per *observation*, but seaborn reads
    # a list palette as one colour per *category*, so a plain list would
    # mis-colour repeated categories. A {category: colour} mapping is unambiguous.
    # Assigning into kwargs also avoids a duplicate "palette" keyword when the
    # caller passed a falsy palette such as None.
    kwargs["palette"] = dict(zip(x, set_custom_palette(x, y)))
    sns.barplot(x=x, y=y, **kwargs)

class DrawingResults:
    """Plot evaluation results held in a pandas DataFrame.

    Each plotting method builds a seaborn ``FacetGrid`` and passes it to
    ``__draw``. That method applies common styling, then saves the figure when
    ``output_filename`` is given, or shows it otherwise.

    Keyword arguments accepted by the plotting methods (``**draw_args``):

    - ``figsize``: ``(width, height)`` in inches.
    - ``rotation``: x tick label rotation in degrees (default 90).
    - ``output_filename``: save to this path instead of calling ``plt.show()``.
    - ``save_param``: dict of extra keyword arguments for ``FacetGrid.savefig``.
    - ``plot_func``: function mapped on each facet (default ``sns.barplot``).
      Not used by ``caracteristic_distribution``.
    - ``cmap``: palette name, used by ``parameter_impact`` (default ``"tab20"``).
    """

    # Aggregations accepted by ``agg_func``; each is a pandas GroupBy method name.
    _AGG_FUNCS = ("mean", "max", "min", "std")

    def __init__(self, df_results):
        # NOTE(review): presumably the frame returned by ``load_data``. The
        # methods read the columns name, nb_edge, size, type_graph,
        # type_method, parameters and the chosen metric column.
        self.df = df_results

    @classmethod
    def _aggregate(cls, df, by, agg_func):
        """Group ``df`` by the columns ``by`` and reduce each group with ``agg_func``.

        The grouping columns are kept as regular columns (``as_index=False``).
        Non-numeric, non-grouping columns are dropped.

        Raises
        ------
        ValueError
            If ``agg_func`` is not one of ``_AGG_FUNCS``.
        """
        if agg_func not in cls._AGG_FUNCS:
            raise ValueError("Method {0} does not exists in pandas.core.groupby.generic.DataFrameGroupBy".format(agg_func))
        grouped = df.groupby(by, as_index=False)
        # numeric_only=True: columns such as file names or parameter dicts cannot
        # be reduced. pandas >= 2 raises on them instead of silently dropping them.
        return getattr(grouped, agg_func)(numeric_only=True)

    def __draw(self, g, **kwargs):
        """Apply shared styling to FacetGrid ``g``, then save or show it.

        See the class docstring for the keyword arguments read here; other keys
        are ignored.
        """
        if "figsize" in kwargs:
            g.fig.set_size_inches(*kwargs["figsize"])

        rotation = kwargs.get("rotation", 90)
        for ax in g.axes.flat:
            plt.setp(ax.get_xticklabels(), rotation=rotation)

        g.fig.subplots_adjust(wspace=.09)

        output_filename = kwargs.get("output_filename", None)
        if output_filename:
            save_params = {}
            if isinstance(kwargs.get("save_param"), dict):
                save_params.update(kwargs["save_param"])
            g.savefig(output_filename, **save_params)
        else:
            plt.show()

    def metric_per_nodes_edges(self, type_graph=None, agg_func=None, metric="auroc", **draw_args):
        """Plot ``metric`` per method on a grid of graph size (rows) x edge count (columns).

        Parameters
        ----------
        type_graph : str, optional
            Keep only this graph type. It is silently ignored when falsy or
            absent from the data.
        agg_func : str, optional
            If given, first reduce the rows of each (name, nb_edge, size,
            type_graph, type_method) group with this aggregation.
        metric : str, optional
            Column plotted on the y axis.
        **draw_args
            Drawing options (see the class docstring).

        Raises
        ------
        ValueError
            If ``agg_func`` is not a supported aggregation.
        """
        new_df = self.df.copy()
        if agg_func:
            new_df = self._aggregate(new_df, "name nb_edge size type_graph type_method".split(), agg_func)

        if type_graph and type_graph in new_df.type_graph.unique():
            new_df = new_df[new_df.type_graph == type_graph].copy()

        g = sns.FacetGrid(new_df, row="size", col="nb_edge", margin_titles=True)
        plot_func = draw_args.get('plot_func', sns.barplot)
        g.map(plot_func, "name", metric)

        return self.__draw(g, **draw_args)

    def metric_global(self, agg_func=None, metric="auroc", **draw_args):
        """Plot ``metric`` per method, with one facet per graph type.

        Parameters
        ----------
        agg_func : str, optional
            If given, apply the aggregation twice. First over the rows of each
            (name, nb_edge, size, type_graph, type_method) group, then over the
            resulting rows of each (name, type_graph, type_method) group.
        metric : str, optional
            Column plotted on the y axis.
        **draw_args
            Drawing options (see the class docstring).

        Raises
        ------
        ValueError
            If ``agg_func`` is not a supported aggregation.
        """
        new_df = self.df.copy()
        if agg_func:
            new_df = self._aggregate(new_df, "name nb_edge size type_graph type_method".split(), agg_func)
            new_df = self._aggregate(new_df, "name type_graph type_method".split(), agg_func)

        g = sns.FacetGrid(new_df, col="type_graph", col_wrap=2, margin_titles=True)
        plot_func = draw_args.get('plot_func', sns.barplot)
        g.map(plot_func, "name", metric, palette="tab20")

        return self.__draw(g, **draw_args)

    def caracteristic_distribution(self, caracteristic, **draw_args):
        """Plot a histogram of column ``caracteristic``, with one facet per graph type."""
        g = sns.FacetGrid(self.df, col="type_graph", col_wrap=4)
        g.map(sns.histplot, caracteristic)

        return self.__draw(g, **draw_args)

    def parameter_impact(self, type_graph, parameter, second_parameter="size", metric="auroc", **draw_args):
        """Plot ``metric`` per method for one graph type, faceted by a generation parameter.

        Parameters
        ----------
        type_graph : str
            Graph type to keep.
        parameter : str
            Key looked up in each row's ``parameters`` dict. Its values become
            the facet columns.
        second_parameter : str, optional
            Existing column whose values become the facet rows (default ``"size"``).
        metric : str, optional
            Column plotted on the y axis.
        **draw_args
            Drawing options (see the class docstring).

        Raises
        ------
        KeyError
            If a row's ``parameters`` dict lacks ``parameter``.
        """
        _df = self.df[self.df.type_graph == type_graph].copy()
        _df[parameter] = _df.parameters.apply(lambda x: x[parameter])

        g = sns.FacetGrid(_df, row=second_parameter, col=parameter, margin_titles=True, height=2.5)
        plot_func = draw_args.get('plot_func', sns.barplot)
        g.map(plot_func, "name", metric, palette=draw_args.get("cmap", "tab20"))

        return self.__draw(g, **draw_args)