diff --git a/GEODE/Visualisation/ConfusionMatrix.py b/GEODE/Visualisation/ConfusionMatrix.py index df0f5aa6946f362dbb0716620c08e8307d891a15..b3c099c70336f0a38a2e16b9195148fb6fda6f9b 100644 --- a/GEODE/Visualisation/ConfusionMatrix.py +++ b/GEODE/Visualisation/ConfusionMatrix.py @@ -1,18 +1,9 @@ import argparse -from GEODE.Store import prepare, tabular -from GEODE.Visualisation.Legend import trim -import matplotlib.pyplot as plot +from GEODE.Store import JSON, tabular +from GEODE.Visualisation.Label import add_labels_argument, getLabels import pandas -import seaborn from sklearn.metrics import confusion_matrix -def heatmap(matrix, filePath, labels, **kwargs): - plot.figure(figsize=(8,7)) - ax = seaborn.heatmap( - matrix, xticklabels=labels, yticklabels=labels, **kwargs - ) - plot.savefig(prepare(filePath), dpi=300, bbox_inches='tight') - def fromList(data, labels): truth = [d['truth'] for d in data] if labels is None: @@ -34,34 +25,46 @@ def prepareData(data, labels=None): msg = "Unsupported data format {f} to represent a confusion matrix" raise Exception(msg.format(f=type(data))) -def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs): +def maxOutOfDiagonal(n, row): + l = len(row) + m = max([row[i] for i in range(l) if i != n]) + return [row[i] if i != n and row[i] == m else 0 for i in range(l)] + +def maxConfusionMatrix(m): + return [maxOutOfDiagonal(i, m[i]) for i in range(len(m))] + +def confusionMatrix(data, labels): truth, answers, labels = prepareData(data, labels=labels) matrix = confusion_matrix(truth, answers, labels=labels, normalize='true') - heatmap(matrix, outputFile, trim(labels, maxWidth), **kwargs) + return {'matrix': matrix.tolist(), 'labels': labels} + +def getConfusionMatrix(inputFile, labels): + if inputFile[-4:] == '.tsv': + return confusionMatrix(tabular(inputFile), labels) + elif inputFile[-5:] == '.json': + return JSON.load(inputFile) + +def extractConfusionMatrix(inputFile, outputJSON, labels=None, maximal=False): + data = getConfusionMatrix(inputFile, labels) + if maximal: + data['matrix'] = maxConfusionMatrix(data['matrix']) + JSON.save(data, outputJSON) def getArgs(arguments): - cli = argparse.ArgumentParser( - prog='confusionMatrix', - description="Draw a confusion matrix from the result of a prediction") - cli.add_argument('inputTSV') - cli.add_argument('outputPNG') - cli.add_argument('-l', '--labels', - help="path to a file containing one label per line") - cli.add_argument('-w', '--maxWidth', type=int, - help="length from which labels will be truncated") - cli.add_argument('-c', '--cmap', help="color map to use") + description = "Extract a confusion matrix from the result of a prediction" + cli = argparse.ArgumentParser(prog='confusionMatrix', + description=description) + cli.add_argument('inputFile') + cli.add_argument('outputJSON') + add_labels_argument(cli) + cli.add_argument('-m', '--maximal', + action='store_const', const=True, default=False) return cli.parse_args(arguments) -def drawConfusionMatrixCLI(arguments): +def extractConfusionMatrixCLI(arguments): args = getArgs(arguments) - data = tabular(args.inputTSV) - if args.labels is not None: - with open(args.labels, 'r') as labelsFile: - labels = list(map(lambda x: x.strip(), labelsFile)) - else: - labels = None - drawConfusionMatrix(data, - args.outputPNG, - labels=labels, - maxWidth=args.maxWidth, - cmap=args.cmap) + labels = getLabels(args) + extractConfusionMatrix(args.inputFile, + args.outputJSON, + labels=labels, + maximal=args.maximal) diff --git a/GEODE/Visualisation/DrawMatrix.py b/GEODE/Visualisation/DrawMatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf633dd3e1d18e3ba95f9acd9941d9ed6f9e4ff --- /dev/null +++ b/GEODE/Visualisation/DrawMatrix.py @@ -0,0 +1,36 @@ +import argparse +from GEODE.Store import prepare +from GEODE.Visualisation.ConfusionMatrix import getConfusionMatrix +from GEODE.Visualisation.Label import add_labels_argument, getLabels +from GEODE.Visualisation.Legend import trim +import matplotlib.pyplot as plot +import seaborn + +def drawMatrix(data, outputFile, maxWidth=None, **kwargs): + ticks = trim(data['labels'], maxWidth) + plot.figure(figsize=(8,7)) + ax = seaborn.heatmap( + data['matrix'], xticklabels=ticks, yticklabels=ticks, **kwargs + ) + plot.savefig(prepare(outputFile), dpi=300, bbox_inches='tight') + +def getArgs(arguments): + cli = argparse.ArgumentParser( + prog='drawMatrix', + description="Draw a matrix") + cli.add_argument('inputFile') + cli.add_argument('outputPNG') + add_labels_argument(cli) + cli.add_argument('-w', '--maxWidth', type=int, + help="length from which labels will be truncated") + cli.add_argument('-c', '--cmap', help="color map to use") + return cli.parse_args(arguments) + +def drawMatrixCLI(arguments): + args = getArgs(arguments) + labels = getLabels(args) + data = getConfusionMatrix(args.inputFile, labels) + drawMatrix(data, + args.outputPNG, + maxWidth=args.maxWidth, + cmap=args.cmap) diff --git a/GEODE/Visualisation/Label.py b/GEODE/Visualisation/Label.py new file mode 100644 index 0000000000000000000000000000000000000000..1d54fceddc698a4c01af77c1a452b3d66b4b6231 --- /dev/null +++ b/GEODE/Visualisation/Label.py @@ -0,0 +1,18 @@ +from argparse import Namespace + +def add_labels_argument(cli): + cli.add_argument('-l', '--labels', + help="path to a file containing one label per line") + +def getLabels(args): + if type(args) == Namespace: + args = vars(args) + if type(args) == dict and 'labels' in args: + path = args['labels'] + elif type(args) == str: + path = args + else: + path = None + if path is not None: + with open(path, 'r') as labelsFile: + return list(map(lambda x: x.strip(), labelsFile)) diff --git a/GEODE/Visualisation/__init__.py b/GEODE/Visualisation/__init__.py index 2eb7bce74fb7baa74ea7d1a36a095c0cfc315d43..fdb234dc9fb1553410fb2de3bebb40a38865c44f 100644 --- a/GEODE/Visualisation/__init__.py +++ b/GEODE/Visualisation/__init__.py @@ -1,3 +1,6 @@ -from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap -from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity +from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrix, \ + maxConfusionMatrix +from GEODE.Visualisation.DensityProfile import densityProfile, \ + drawDensityProfile, plotDensity +from GEODE.Visualisation.DrawMatrix import drawMatrix from GEODE.Visualisation.Legend import trim as legend diff --git a/GEODE/__init__.py b/GEODE/__init__.py index 57b432410b8436aecd9682ecced85ccaabb969f4..50b9fe0eac94b025ea1f54a73fbb63d0e138fb4e 100644 --- a/GEODE/__init__.py +++ b/GEODE/__init__.py @@ -22,14 +22,17 @@ from GEODE.Classification import fullLemmas, isStopWord, superdomains as domains from GEODE.ENE import eneLabels from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \ fromKey, relativePath, toKey, uid -from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV -from GEODE.Visualisation import densityProfile, heatmap, legend -from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI +from GEODE.Store import corpus, Directory, JSON, SelfContained, tabular, toTSV +from GEODE.Visualisation import densityProfile, drawMatrix, legend, \ + maxConfusionMatrix +from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrixCLI from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI +from GEODE.Visualisation.DrawMatrix import drawMatrixCLI commands = { - 'confusionMatrix': drawConfusionMatrixCLI, - 'densityProfile': drawDensityProfileCLI + 'confusionMatrix': extractConfusionMatrixCLI, + 'densityProfile': drawDensityProfileCLI, + 'drawMatrix': drawMatrixCLI, } def geopyckCLI():