From 4ed9e1d795bc3689cf213cbbfda7353acdfcaa83 Mon Sep 17 00:00:00 2001 From: Alice BRENON <alice.brenon@ens-lyon.fr> Date: Thu, 21 Mar 2024 15:10:51 +0100 Subject: [PATCH] Expose label-trimming logic as a separate module for reuse --- GEODE/Visualisation/ConfusionMatrix.py | 15 +++------------ GEODE/Visualisation/Legend.py | 23 +++++++++++++++++++++++ GEODE/Visualisation/__init__.py | 1 + GEODE/__init__.py | 2 +- 4 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 GEODE/Visualisation/Legend.py diff --git a/GEODE/Visualisation/ConfusionMatrix.py b/GEODE/Visualisation/ConfusionMatrix.py index af140f3..df0f5aa 100644 --- a/GEODE/Visualisation/ConfusionMatrix.py +++ b/GEODE/Visualisation/ConfusionMatrix.py @@ -1,22 +1,13 @@ import argparse from GEODE.Store import prepare, tabular +from GEODE.Visualisation.Legend import trim import matplotlib.pyplot as plot import pandas import seaborn from sklearn.metrics import confusion_matrix -def trim(name, maxSize): - if len(name) > maxSize: - components = name.split(' ') - return components[0] + ' […]' - else: - return name - -def trimLabels(labels, maxWidth): - return labels if maxWidth is None else [trim(l, maxWidth) for l in labels] - def heatmap(matrix, filePath, labels, **kwargs): - plot.figure(figsize=(16,13)) + plot.figure(figsize=(8,7)) ax = seaborn.heatmap( matrix, xticklabels=labels, yticklabels=labels, **kwargs ) @@ -46,7 +37,7 @@ def prepareData(data, labels=None): def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs): truth, answers, labels = prepareData(data, labels=labels) matrix = confusion_matrix(truth, answers, labels=labels, normalize='true') - heatmap(matrix, outputFile, trimLabels(labels, maxWidth), **kwargs) + heatmap(matrix, outputFile, trim(labels, maxWidth), **kwargs) def getArgs(arguments): cli = argparse.ArgumentParser( diff --git a/GEODE/Visualisation/Legend.py b/GEODE/Visualisation/Legend.py new file mode 100644 index 0000000..7aa292e --- /dev/null +++ b/GEODE/Visualisation/Legend.py @@ -0,0 +1,23 @@ +from GEODE.Functional import curry + +@curry +def take(maxWidth, shards): + result = [] + budget = maxWidth + i = 0 + while i < len(shards) and budget > 1: + shard = shards[i] + if len(shard) > budget: + shard = f"{shard[:budget-1]}." + result.append(shard) + budget -= len(shard) + 1 + i += 1 + return tuple(result) + +def trim(labels, maxWidth=10): + if maxWidth is None: + return labels + else: + shards = [label.split(' ') for label in labels] + prefixes = [*map(take(maxWidth), shards)] + return [' '.join(prefix) for prefix in sorted(prefixes)] diff --git a/GEODE/Visualisation/__init__.py b/GEODE/Visualisation/__init__.py index 981a917..2eb7bce 100644 --- a/GEODE/Visualisation/__init__.py +++ b/GEODE/Visualisation/__init__.py @@ -1,2 +1,3 @@ from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity +from GEODE.Visualisation.Legend import trim as legend diff --git a/GEODE/__init__.py b/GEODE/__init__.py index 563bce7..e7735f5 100644 --- a/GEODE/__init__.py +++ b/GEODE/__init__.py @@ -23,7 +23,7 @@ from GEODE.ENE import eneLabels from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \ fromKey, relativePath, toKey, uid from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV -from GEODE.Visualisation import densityProfile, heatmap +from GEODE.Visualisation import densityProfile, heatmap, legend from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI -- GitLab