# confusionMatrices.py
#!/usr/bin/env python3

from EDdA import data
from EDdA.classification import confusionMatrix, metrics, toPNG, topNGrams
import os
import sys

def preparePath(root, source, n, ranks, metricName):
    """Return the PNG path for one confusion matrix, creating its directory.

    The path has the form
    ``{root}/confusionMatrix/{source.hash}/{n}grams_top{ranks}_{metricName}.png``.
    Missing parent directories are created; directories that already exist
    are accepted silently. The image file itself is not created here.
    """
    template = "{root}/confusionMatrix/{inputHash}/{n}grams_top{ranks}_{name}.png"
    fields = {
        "root": root,
        "inputHash": source.hash,
        "n": n,
        "ranks": ranks,
        "name": metricName,
    }
    imagePath = template.format(**fields)
    parentDir = os.path.dirname(imagePath)
    # exist_ok lets repeated runs reuse the same output tree.
    os.makedirs(parentDir, exist_ok=True)
    return imagePath

def __syntax(this):
    """Print the command-line usage on stderr and exit with status 1.

    ``this`` is the program name displayed at the start of the usage line.
    """
    usage = "Syntax: {this} {required} {optional}"
    arguments = {
        "this": this,
        "required": "ARTICLES_DATA(.csv) OUTPUT_DIR",
        "optional": "[NGRAM SIZE] [TOP_RANKS_SIZE] [METRIC_NAME]",
    }
    print(usage.format(**arguments), file=sys.stderr)
    sys.exit(1)

def __compute(sourcePath, ns, ranksToTry, metricNames, outputDir):
    """Write one confusion-matrix PNG per (n, ranks, metric) combination.

    The input is loaded once from ``sourcePath`` with ``data.load``. For each
    pair taken from ``ns`` and ``ranksToTry`` a single ``topNGrams`` result is
    built and reused for every name in ``metricNames``. Each name is looked up
    in ``metrics``, and the resulting image is written to the path returned by
    ``preparePath`` under ``outputDir``.
    """
    corpus = data.load(sourcePath)
    for gramSize in ns:
        for topRanks in ranksToTry:
            vectorizer = topNGrams(corpus, gramSize, topRanks)
            for metricName in metricNames:
                # The output directory is prepared before the metric lookup,
                # so it exists even if the lookup below fails.
                target = preparePath(outputDir, corpus, gramSize, topRanks, metricName)
                matrix = confusionMatrix(vectorizer, metrics[metricName])
                toPNG(matrix, target)
if __name__ == '__main__':
    # Command-line entry point: ARTICLES_DATA and OUTPUT_DIR are mandatory;
    # each of the three trailing arguments pins one axis of the sweep.
    argc = len(sys.argv)
    # sys.argv[0] is the script name, so both required arguments are present
    # only when argc >= 3. With fewer, reading sys.argv[2] below would raise
    # IndexError instead of showing the usage message.
    if argc < 3:
        __syntax(sys.argv[0])
    else:
        sourcePath = sys.argv[1]
        outputDir = sys.argv[2]
        # When an optional argument is omitted, fall back to the default
        # sweep for that axis (for the metric: every key of `metrics`).
        ns = [int(sys.argv[3])] if argc > 3 else range(1,4)
        ranksToTry = [int(sys.argv[4])] if argc > 4 else [10, 100, 50]
        metricNames = [sys.argv[5]] if argc > 5 else metrics.keys()
        __compute(sourcePath, ns, ranksToTry, metricNames, outputDir)