diff --git a/GEODE/Visualisation/ConfusionMatrix.py b/GEODE/Visualisation/ConfusionMatrix.py
index af140f30b697f8544382db92c4a43cc225b4114f..df0f5aa6946f362dbb0716620c08e8307d891a15 100644
--- a/GEODE/Visualisation/ConfusionMatrix.py
+++ b/GEODE/Visualisation/ConfusionMatrix.py
@@ -1,22 +1,13 @@
 import argparse
 from GEODE.Store import prepare, tabular
+from GEODE.Visualisation.Legend import trim
 import matplotlib.pyplot as plot
 import pandas
 import seaborn
 from sklearn.metrics import confusion_matrix
 
-def trim(name, maxSize):
-    if len(name) > maxSize:
-        components = name.split(' ')
-        return components[0] + ' […]'
-    else:
-        return name
-
-def trimLabels(labels, maxWidth):
-    return labels if maxWidth is None else [trim(l, maxWidth) for l in labels]
-
 def heatmap(matrix, filePath, labels, **kwargs):
-    plot.figure(figsize=(16,13))
+    plot.figure(figsize=(8,7))
     ax = seaborn.heatmap(
             matrix, xticklabels=labels, yticklabels=labels, **kwargs
         )
@@ -46,7 +37,7 @@ def prepareData(data, labels=None):
 def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs):
     truth, answers, labels = prepareData(data, labels=labels)
     matrix = confusion_matrix(truth, answers, labels=labels, normalize='true')
-    heatmap(matrix, outputFile, trimLabels(labels, maxWidth), **kwargs)
+    heatmap(matrix, outputFile, trim(labels, maxWidth), **kwargs)
 
 def getArgs(arguments):
     cli = argparse.ArgumentParser(
diff --git a/GEODE/Visualisation/Legend.py b/GEODE/Visualisation/Legend.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aa292ea0b7de98b8be29921b238c107f37f9ec8
--- /dev/null
+++ b/GEODE/Visualisation/Legend.py
@@ -0,0 +1,43 @@
+from GEODE.Functional import curry
+
+@curry
+def take(maxWidth, shards):
+    """Keep the leading words of a label that fit in maxWidth characters.
+
+    Every kept word costs its length plus one (the separator). A word longer
+    than the remaining budget is cut and closed with a '.' so that it uses up
+    exactly that budget, which also ends the selection.
+
+    maxWidth -- number of characters available
+    shards -- the words of the label, in order
+    Returns the kept (possibly abbreviated) words as a tuple.
+    """
+    result = []
+    budget = maxWidth
+    for shard in shards:
+        # stop as soon as fewer than 2 characters remain
+        if budget <= 1:
+            break
+        if len(shard) > budget:
+            shard = f"{shard[:budget-1]}."
+        result.append(shard)
+        budget -= len(shard) + 1
+    return tuple(result)
+
+def trim(labels, maxWidth=10):
+    """Shorten every label to at most maxWidth characters, cutting on spaces.
+
+    The output follows the order of the input (the i-th result abbreviates the
+    i-th label) so it can be used as tick labels for data that was computed
+    from the untrimmed labels.
+
+    labels -- the labels to shorten
+    maxWidth -- largest width allowed; None returns labels untouched
+    """
+    if maxWidth is None:
+        return labels
+    else:
+        shards = [label.split(' ') for label in labels]
+        # keep input order: the labels must stay aligned with their data
+        prefixes = map(take(maxWidth), shards)
+        return [' '.join(prefix) for prefix in prefixes]
diff --git a/GEODE/Visualisation/__init__.py b/GEODE/Visualisation/__init__.py
index 981a91757c5b5d4f6811f59844c12101ecc2d4f3..2eb7bce74fb7baa74ea7d1a36a095c0cfc315d43 100644
--- a/GEODE/Visualisation/__init__.py
+++ b/GEODE/Visualisation/__init__.py
@@ -1,2 +1,3 @@
 from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap
 from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity
+from GEODE.Visualisation.Legend import trim as legend
diff --git a/GEODE/__init__.py b/GEODE/__init__.py
index 563bce75422c75fc354eee084877a9af0ead7836..e7735f5fba302172f401b79b3b2fb6223885b756 100644
--- a/GEODE/__init__.py
+++ b/GEODE/__init__.py
@@ -23,7 +23,7 @@ from GEODE.ENE import eneLabels
 from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \
         fromKey, relativePath, toKey, uid
 from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV
-from GEODE.Visualisation import densityProfile, heatmap
+from GEODE.Visualisation import densityProfile, heatmap, legend
 from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI
 from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI
 