Skip to content
Snippets Groups Projects
Commit 4ed9e1d7 authored by Alice Brenon
Browse files

Expose label-trimming logic as a separate module for reuse

parent 4babdc56
No related branches found
No related tags found
No related merge requests found
import argparse import argparse
from GEODE.Store import prepare, tabular from GEODE.Store import prepare, tabular
from GEODE.Visualisation.Legend import trim
import matplotlib.pyplot as plot import matplotlib.pyplot as plot
import pandas import pandas
import seaborn import seaborn
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
def trim(name, maxSize):
    """Shorten a single label that is longer than ``maxSize`` characters.

    A label that fits is returned untouched. Otherwise only the text before
    the first space is kept and the marker `` […]`` is appended to signal the
    cut. The shortened form is not re-checked against ``maxSize``.
    """
    if len(name) <= maxSize:
        return name
    firstWord = name.split(' ', 1)[0]
    return f"{firstWord} […]"
def trimLabels(labels, maxWidth):
    """Apply ``trim`` to every label, or do nothing when no width is given.

    When ``maxWidth`` is ``None`` the ``labels`` object itself is returned;
    otherwise a new list is built with each label passed through
    ``trim(label, maxWidth)``, in the same order as the input.
    """
    if maxWidth is None:
        return labels
    return [trim(label, maxWidth) for label in labels]
def heatmap(matrix, filePath, labels, **kwargs): def heatmap(matrix, filePath, labels, **kwargs):
plot.figure(figsize=(16,13)) plot.figure(figsize=(8,7))
ax = seaborn.heatmap( ax = seaborn.heatmap(
matrix, xticklabels=labels, yticklabels=labels, **kwargs matrix, xticklabels=labels, yticklabels=labels, **kwargs
) )
...@@ -46,7 +37,7 @@ def prepareData(data, labels=None): ...@@ -46,7 +37,7 @@ def prepareData(data, labels=None):
def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs): def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs):
truth, answers, labels = prepareData(data, labels=labels) truth, answers, labels = prepareData(data, labels=labels)
matrix = confusion_matrix(truth, answers, labels=labels, normalize='true') matrix = confusion_matrix(truth, answers, labels=labels, normalize='true')
heatmap(matrix, outputFile, trimLabels(labels, maxWidth), **kwargs) heatmap(matrix, outputFile, trim(labels, maxWidth), **kwargs)
def getArgs(arguments): def getArgs(arguments):
cli = argparse.ArgumentParser( cli = argparse.ArgumentParser(
......
from GEODE.Functional import curry
@curry
def take(maxWidth, shards):
    """Keep the leading shards that fit within a budget of ``maxWidth``.

    Shards are consumed in order while more than one unit of budget remains.
    Each kept shard costs its own length plus one. A shard longer than the
    remaining budget is cut so that, with a trailing ``.`` appended, it uses
    exactly what is left.

    Returns the kept shards as a tuple.
    """
    kept = []
    remaining = maxWidth
    for word in shards:
        # With a budget of 1 or less there is no room for a character
        # followed by the "." abbreviation mark, so stop here.
        if remaining <= 1:
            break
        if len(word) > remaining:
            word = word[:remaining - 1] + "."
        kept.append(word)
        # The extra unit accounts for the separator that follows the shard
        # (callers join the result with a single space).
        remaining -= len(word) + 1
    return tuple(kept)
def trim(labels, maxWidth=10):
    """Shorten every label so that it fits within ``maxWidth`` characters.

    Each label is split on single spaces and the resulting words are handed
    to ``take(maxWidth)``, which keeps as many leading words as fit and
    abbreviates the last one with a trailing ``.`` if needed. The kept words
    are then joined back with single spaces.

    Parameters:
        labels: iterable of label strings.
        maxWidth: width budget per label; ``None`` disables trimming and
            returns ``labels`` unchanged.

    Returns:
        ``labels`` itself when ``maxWidth`` is ``None``, otherwise a new list
        holding one shortened label per input label, in input order.
    """
    if maxWidth is None:
        return labels
    # `take` is curried, so giving it only the width yields a per-label
    # shortening function.
    shorten = take(maxWidth)
    # Keep the input order: the result is used positionally (e.g. as heatmap
    # tick labels), so sorting the shortened labels would attach them to the
    # wrong rows and columns.
    return [' '.join(shorten(label.split(' '))) for label in labels]
from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap
from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity
from GEODE.Visualisation.Legend import trim as legend
...@@ -23,7 +23,7 @@ from GEODE.ENE import eneLabels ...@@ -23,7 +23,7 @@ from GEODE.ENE import eneLabels
from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \ from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \
fromKey, relativePath, toKey, uid fromKey, relativePath, toKey, uid
from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV
from GEODE.Visualisation import densityProfile, heatmap from GEODE.Visualisation import densityProfile, heatmap, legend
from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI
from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment