Skip to content
Snippets Groups Projects
Commit b71fe134 authored by Alice Brenon's avatar Alice Brenon
Browse files

Separate matrix handling from matrix ploting + add support for the Maximal Confusion Matrices

parent 414fd55c
No related branches found
No related tags found
No related merge requests found
import argparse
from GEODE.Store import prepare, tabular
from GEODE.Visualisation.Legend import trim
import matplotlib.pyplot as plot
from GEODE.Store import JSON, tabular
from GEODE.Visualisation.Label import add_labels_argument, getLabels
import pandas
import seaborn
from sklearn.metrics import confusion_matrix
def heatmap(matrix, filePath, labels, **kwargs):
plot.figure(figsize=(8,7))
ax = seaborn.heatmap(
matrix, xticklabels=labels, yticklabels=labels, **kwargs
)
plot.savefig(prepare(filePath), dpi=300, bbox_inches='tight')
def fromList(data, labels):
truth = [d['truth'] for d in data]
if labels is None:
......@@ -34,34 +25,46 @@ def prepareData(data, labels=None):
msg = "Unsupported data format {f} to represent a confusion matrix"
raise Exception(msg.format(f=type(data)))
def drawConfusionMatrix(data, outputFile, labels=None, maxWidth=None, **kwargs):
def maxOutOfDiagonal(n, row):
l = len(row)
m = max([row[i] for i in range(l) if i != n])
return [row[i] if i != n and row[i] == m else 0 for i in range(l)]
def maxConfusionMatrix(m):
return [maxOutOfDiagonal(i, m[i]) for i in range(len(m))]
def confusionMatrix(data, labels):
truth, answers, labels = prepareData(data, labels=labels)
matrix = confusion_matrix(truth, answers, labels=labels, normalize='true')
heatmap(matrix, outputFile, trim(labels, maxWidth), **kwargs)
return {'matrix': matrix.tolist(), 'labels': labels}
def getConfusionMatrix(inputFile, labels):
if inputFile[-4:] == '.tsv':
return confusionMatrix(tabular(inputFile), labels)
elif inputFile[-5:] == '.json':
return JSON.load(inputFile)
def extractConfusionMatrix(inputFile, outputJSON, labels=None, maximal=False):
data = getConfusionMatrix(inputFile, labels)
if maximal:
data['matrix'] = maxConfusionMatrix(data['matrix'])
JSON.save(data, outputJSON)
def getArgs(arguments):
cli = argparse.ArgumentParser(
prog='confusionMatrix',
description="Draw a confusion matrix from the result of a prediction")
cli.add_argument('inputTSV')
cli.add_argument('outputPNG')
cli.add_argument('-l', '--labels',
help="path to a file containing one label per line")
cli.add_argument('-w', '--maxWidth', type=int,
help="length from which labels will be truncated")
cli.add_argument('-c', '--cmap', help="color map to use")
description = "Extract a confusion matrix from the result of a prediction"
cli = argparse.ArgumentParser(prog='confusionMatrix',
description=description)
cli.add_argument('inputFile')
cli.add_argument('outputJSON')
add_labels_argument(cli)
cli.add_argument('-m', '--maximal',
action='store_const', const=True, default=False)
return cli.parse_args(arguments)
def drawConfusionMatrixCLI(arguments):
def extractConfusionMatrixCLI(arguments):
args = getArgs(arguments)
data = tabular(args.inputTSV)
if args.labels is not None:
with open(args.labels, 'r') as labelsFile:
labels = list(map(lambda x: x.strip(), labelsFile))
else:
labels = None
drawConfusionMatrix(data,
args.outputPNG,
labels=labels,
maxWidth=args.maxWidth,
cmap=args.cmap)
labels = getLabels(args)
extractConfusionMatrix(args.inputFile,
args.outputJSON,
labels=labels,
maximal=args.maximal)
import argparse
from GEODE.Store import prepare
from GEODE.Visualisation.ConfusionMatrix import getConfusionMatrix
from GEODE.Visualisation.Label import add_labels_argument, getLabels
from GEODE.Visualisation.Legend import trim
import matplotlib.pyplot as plot
import seaborn
def drawMatrix(data, outputFile, maxWidth=None, **kwargs):
ticks = trim(data['labels'], maxWidth)
plot.figure(figsize=(8,7))
ax = seaborn.heatmap(
data['matrix'], xticklabels=ticks, yticklabels=ticks, **kwargs
)
plot.savefig(prepare(outputFile), dpi=300, bbox_inches='tight')
def getArgs(arguments):
cli = argparse.ArgumentParser(
prog='drawMatrix',
description="Draw a matrix")
cli.add_argument('inputFile')
cli.add_argument('outputPNG')
add_labels_argument(cli)
cli.add_argument('-w', '--maxWidth', type=int,
help="length from which labels will be truncated")
cli.add_argument('-c', '--cmap', help="color map to use")
return cli.parse_args(arguments)
def drawMatrixCLI(arguments):
args = getArgs(arguments)
labels = getLabels(args)
data = getConfusionMatrix(args.inputFile, labels)
drawMatrix(data,
args.outputPNG,
maxWidth=args.maxWidth,
cmap=args.cmap)
from argparse import Namespace
def add_labels_argument(cli):
cli.add_argument('-l', '--labels',
help="path to a file containing one label per line")
def getLabels(args):
if type(args) == Namespace:
args = vars(args)
if type(args) == dict and 'labels' in args:
path = args['labels']
elif type(args) == str:
path = args
else:
path = None
if path is not None:
with open(path, 'r') as labelsFile:
return list(map(lambda x: x.strip(), labelsFile))
from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrix, heatmap
from GEODE.Visualisation.DensityProfile import densityProfile, drawDensityProfile, plotDensity
from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrix, \
maxConfusionMatrix
from GEODE.Visualisation.DensityProfile import densityProfile, \
drawDensityProfile, plotDensity
from GEODE.Visualisation.DrawMatrix import drawMatrix
from GEODE.Visualisation.Legend import trim as legend
......@@ -22,14 +22,17 @@ from GEODE.Classification import fullLemmas, isStopWord, superdomains as domains
from GEODE.ENE import eneLabels
from GEODE.Metadata import article, articleKey, paragraph, paragraphKey, \
fromKey, relativePath, toKey, uid
from GEODE.Store import corpus, Directory, SelfContained, tabular, toTSV
from GEODE.Visualisation import densityProfile, heatmap, legend
from GEODE.Visualisation.ConfusionMatrix import drawConfusionMatrixCLI
from GEODE.Store import corpus, Directory, JSON, SelfContained, tabular, toTSV
from GEODE.Visualisation import densityProfile, drawMatrix, legend, \
maxConfusionMatrix
from GEODE.Visualisation.ConfusionMatrix import extractConfusionMatrixCLI
from GEODE.Visualisation.DensityProfile import drawDensityProfileCLI
from GEODE.Visualisation.DrawMatrix import drawMatrixCLI
commands = {
'confusionMatrix': drawConfusionMatrixCLI,
'densityProfile': drawDensityProfileCLI
'confusionMatrix': extractConfusionMatrixCLI,
'densityProfile': drawDensityProfileCLI,
'drawMatrix': drawMatrixCLI,
}
def geopyckCLI():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment