Skip to content
Snippets Groups Projects
DrawMatrix.py 1.28 KiB
import argparse
from GEODE.Store import prepare
from GEODE.Visualisation.ConfusionMatrix import getConfusionMatrix
from GEODE.Visualisation.Label import add_labels_argument, getLabels
from GEODE.Visualisation.Legend import trim
import matplotlib.pyplot as plot
import seaborn

def drawMatrix(data, outputFile, maxWidth=None, *, figsize=(8, 7), dpi=300,
               **kwargs):
    """Render ``data`` as a seaborn heatmap and save it to ``outputFile``.

    :param data: mapping with a ``'labels'`` entry (used as the tick labels
        of both axes) and a ``'matrix'`` entry (the values to plot).
    :param outputFile: destination path; passed through ``prepare`` before
        the image is saved.
    :param maxWidth: forwarded to ``trim``; per the CLI help, the length
        from which labels are truncated (``None`` presumably means "no
        truncation" — confirm in ``trim``).
    :param figsize: figure size in inches; defaults to ``(8, 7)``.
    :param dpi: resolution of the saved image; defaults to ``300``.
    :param kwargs: extra keyword arguments forwarded verbatim to
        ``seaborn.heatmap`` (e.g. ``cmap``).
    """
    ticks = trim(data['labels'], maxWidth)
    figure = plot.figure(figsize=figsize)
    try:
        # heatmap draws on the current axes, which belong to the figure
        # created just above.
        seaborn.heatmap(
                data['matrix'], xticklabels=ticks, yticklabels=ticks, **kwargs
            )
        figure.savefig(prepare(outputFile), dpi=dpi, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # without this, repeated calls would accumulate open figures.
        plot.close(figure)

def getArgs(arguments):
    """Parse the command-line ``arguments`` of the ``drawMatrix`` tool.

    Positional arguments are the input file and the output PNG path. The
    label options come from ``add_labels_argument``. Optional flags are
    ``-w/--maxWidth`` (an int) and ``-c/--cmap``.

    :param arguments: list of argument strings handed to ``parse_args``.
    :return: the ``argparse.Namespace`` produced by the parser.
    """
    parser = argparse.ArgumentParser(prog='drawMatrix',
                                     description="Draw a matrix")
    for positional in ('inputFile', 'outputPNG'):
        parser.add_argument(positional)
    add_labels_argument(parser)
    parser.add_argument(
        '-w', '--maxWidth',
        type=int,
        help="length from which labels will be truncated",
    )
    parser.add_argument('-c', '--cmap', help="color map to use")
    return parser.parse_args(arguments)

def drawMatrixCLI(arguments):
    """Command-line entry point: parse ``arguments``, load the matrix, draw it.

    The input file is read through ``getConfusionMatrix`` using the labels
    selected on the command line. The result is rendered into the requested
    PNG with the optional label width and colour map.

    :param arguments: list of command-line argument strings.
    """
    options = getArgs(arguments)
    selectedLabels = getLabels(options)
    matrixData = getConfusionMatrix(options.inputFile, selectedLabels)
    drawMatrix(
        matrixData,
        options.outputPNG,
        maxWidth=options.maxWidth,
        cmap=options.cmap,
    )