diff --git a/GEODE/Visualisation/Graph.py b/GEODE/Visualisation/Graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..de489979aebd7621d20fc18e714570dc4cbf0e0b
--- /dev/null
+++ b/GEODE/Visualisation/Graph.py
@@ -0,0 +1,88 @@
+import argparse
+from os.path import splitext
+from GEODE.Store import prepare
+from GEODE.Visualisation.ConfusionMatrix import getConfusionMatrix
+from GEODE.Visualisation.Legend import trim
+from graphviz import Digraph
+from matplotlib.cm import ScalarMappable
+from matplotlib.colors import Normalize
+from seaborn import color_palette
+
+DEFAULT_CMAP = 'Blues'
+
+def hexColorRange(vmin, vmax, cmap):
+    """Return a function mapping a value to a '#rrggbb' color string.
+
+    Values are normalized between `vmin` and `vmax` (and clipped to that
+    range) before being looked up in the color map `cmap`.
+    """
+    mapper = ScalarMappable(
+            norm=Normalize(vmin=vmin, vmax=vmax, clip=True),
+            cmap=color_palette(cmap, as_cmap=True)
+    )
+    def toHexRGB(color):
+        (r, g, b, _) = mapper.to_rgba(color)
+        return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
+    return toHexRGB
+
+def withBorder(color):
+    """Return `color` three times between two dark borders, ':'-separated."""
+    border = ["#3b3b3b"]
+    return ':'.join(border + 3*[color] + border)
+
+def drawGraph(data, outputFile, cmap=DEFAULT_CMAP, maxWidth=None):
+    """Render the weighted adjacency matrix in `data` as a directed graph.
+
+    `data['matrix']` holds the weights (row = source, column = target) and
+    `data['labels']` the node names, passed through `trim` with `maxWidth`.
+    Only strictly positive weights produce an edge, colored after its weight.
+    The format is taken from the extension of `outputFile` (PNG if none).
+    Return the value of graphviz's `render`.
+    """
+    if cmap is None:
+        cmap = DEFAULT_CMAP
+    matrix = data['matrix']
+    labels = trim(data['labels'], maxWidth)
+    edgeValues = [c for row in matrix for c in row if c is not None]
+    # With no value at all there is no edge to color: any range will do
+    colorize = hexColorRange(min(edgeValues, default=0),
+                             max(edgeValues, default=0),
+                             cmap)
+    g = Digraph()
+    g.graph_attr['rankdir'] = 'LR'
+    g.node_attr['fontsize'] = g.edge_attr['fontsize'] = '22'
+    g.edge_attr['arrowsize'] = g.edge_attr['penwidth'] = '2'
+    for i in range(len(matrix)):
+        g.node(str(i), label=labels[i])
+    for i, row in enumerate(matrix):
+        for j, link in enumerate(row):
+            if link is not None and link > 0:
+                label = f"{link}" if isinstance(link, int) else f"{link:.2f}"
+                color = withBorder(colorize(link))
+                g.edge(str(i), str(j), color=color, label=label)
+    # Do not assume a three-letter extension ('.jpeg', '.ps', none at all...)
+    (basePath, extension) = splitext(outputFile)
+    return g.render(prepare(basePath), format=extension[1:] or 'png')
+
+def getArgs(arguments):
+    """Parse the command-line arguments of the `graph` command."""
+    cli = argparse.ArgumentParser(
+            prog='graph',
+            description="Draw a graph from its weighted adjacency matrix")
+    cli.add_argument('inputJSON')
+    cli.add_argument('outputPNG')
+    cli.add_argument('-w', '--maxWidth', type=int,
+                     help="length from which labels will be truncated")
+    # Without a default, a missing -c would hand None over to drawGraph
+    cli.add_argument('-c', '--cmap', default=DEFAULT_CMAP,
+                     help="color map to use")
+    return cli.parse_args(arguments)
+
+def drawGraphCLI(arguments):
+    """Load the matrix with `getConfusionMatrix` and draw it as a graph."""
+    args = getArgs(arguments)
+    data = getConfusionMatrix(args.inputJSON)
+    drawGraph(data,
+              args.outputPNG,
+              maxWidth=args.maxWidth,
+              cmap=args.cmap)
diff --git a/GEODE/Visualisation/__init__.py b/GEODE/Visualisation/__init__.py
index fdb234dc9fb1553410fb2de3bebb40a38865c44f..3cc602b36c656a14eadf6cdd1efdccd12b4e90bb 100644
--- a/GEODE/Visualisation/__init__.py
+++ b/GEODE/Visualisation/__init__.py
@@ -3,4 +3,5 @@ from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrix, \
 from GEODE.Visualisation.DensityProfile import densityProfile, \
         drawDensityProfile, plotDensity
 from GEODE.Visualisation.DrawMatrix import drawMatrix
+from GEODE.Visualisation.Graph import drawGraph
 from GEODE.Visualisation.Legend import trim as legend
diff --git a/GEODE/__init__.py b/GEODE/__init__.py
index 50b9fe0eac94b025ea1f54a73fbb63d0e138fb4e..a3fd4123c92aef7df0134d17b3f8fca292077863 100644
--- a/GEODE/__init__.py
+++ b/GEODE/__init__.py
@@ -23,16 +23,18 @@ from GEODE.ENE import eneLabels
 from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \
         fromKey, relativePath, toKey, uid
 from GEODE.Store import corpus, Directory, JSON, SelfContained, tabular, toTSV
-from GEODE.Visualisation import densityProfile, drawMatrix, legend, \
+from GEODE.Visualisation import densityProfile, drawMatrix, drawGraph, legend, \
         maxConfusionMatrix
 from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrixCLI
 from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI
 from GEODE.Visualisation.DrawMatrix import drawMatrixCLI
+from GEODE.Visualisation.Graph import drawGraphCLI
 
 commands = {
         'confusionMatrix': extractConfusionMatrixCLI,
         'densityProfile': drawDensityProfileCLI,
         'drawMatrix': drawMatrixCLI,
+        'graph': drawGraphCLI
         }
 
 def geopyckCLI():
diff --git a/guix.scm b/guix.scm
index 8488c739744c29ac6b3d9df0402e8f37c812ed3c..da8e31c80bd7c30663a6a61e50d6ba6de33bbb58 100644
--- a/guix.scm
+++ b/guix.scm
@@ -1,4 +1,5 @@
-(use-modules ((gnu packages machine-learning) #:select (python-pytorch python-scikit-learn python-spacy))
+(use-modules ((gnu packages graphviz) #:select (python-graphviz))
+             ((gnu packages machine-learning) #:select (python-pytorch python-scikit-learn python-spacy))
              ((gnu packages python-science) #:select (python-pandas))
              ((gnu packages python-xyz) #:select (python-matplotlib
                                                   python-nltk
@@ -25,6 +26,7 @@
   (propagated-inputs
     (list nltk-data-corpora-stopwords
           python-frenchleffflemmatizer
+          python-graphviz
           python-matplotlib
           python-nltk
           python-pandas