diff --git a/scripts/modelConfusionMatrix.py b/scripts/modelConfusionMatrix.py new file mode 100755 index 0000000000000000000000000000000000000000..7bd3eb30a3422132f2456146d065f26b89f46574 --- /dev/null +++ b/scripts/modelConfusionMatrix.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +from EDdA.store import preparePath +import pandas +from sklearn.metrics import confusion_matrix +from EDdA.classification import heatmap +import sys + +def modelConfusionMatrix(): + if len(sys.argv) == 3: + input_csv = sys.argv[1] + output_png = sys.argv[2] + model = pandas.read_csv(input_csv, index_col=0) + labels, predictions = model['labels'], model['predictions'] + confusion = confusion_matrix(labels, predictions, normalize='true') + heatmap(confusion, preparePath(output_png)) + else: + print(f"Syntax: {sys.argv[0]} INPUT_CSV OUTPUT_PNG") + +if __name__ == '__main__': + modelConfusionMatrix() diff --git a/setup.py b/setup.py index 1d445ec2e3da069a9a7345f003b3d3c2fc2f3d0e..7ca1db9b399b5e523421b0128ae279e3c115fc85 100644 --- a/setup.py +++ b/setup.py @@ -4,4 +4,5 @@ from setuptools import setup setup(name='PyEDdA', version='0.1', - packages=['EDdA', 'EDdA.classification']) + packages=['EDdA', 'EDdA.classification'], + scripts=['scripts/modelConfusionMatrix.py'])