From 4b1f435b948e276addb5e202a2d071fa10c33fe3 Mon Sep 17 00:00:00 2001 From: Alice BRENON <alice.brenon@ens-lyon.fr> Date: Mon, 6 Nov 2023 10:54:24 +0100 Subject: [PATCH] Add script to generate confusion matrices --- scripts/modelConfusionMatrix.py | 21 +++++++++++++++++++++ setup.py | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100755 scripts/modelConfusionMatrix.py diff --git a/scripts/modelConfusionMatrix.py b/scripts/modelConfusionMatrix.py new file mode 100755 index 0000000..7bd3eb3 --- /dev/null +++ b/scripts/modelConfusionMatrix.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +from EDdA.store import preparePath +import pandas +from sklearn.metrics import confusion_matrix +from EDdA.classification import heatmap +import sys + +def modelConfusionMatrix(): + if len(sys.argv) == 3: + input_csv = sys.argv[1] + output_png = sys.argv[2] + model = pandas.read_csv(input_csv, index_col=0) + labels, predictions = model['labels'], model['predictions'] + confusion = confusion_matrix(labels, predictions, normalize='true') + heatmap(confusion, preparePath(output_png)) + else: + print(f"Syntax: {sys.argv[0]} INPUT_CSV OUTPUT_PNG") + +if __name__ == '__main__': + modelConfusionMatrix() diff --git a/setup.py b/setup.py index 1d445ec..7ca1db9 100644 --- a/setup.py +++ b/setup.py @@ -4,4 +4,5 @@ from setuptools import setup setup(name='PyEDdA', version='0.1', - packages=['EDdA', 'EDdA.classification']) + packages=['EDdA', 'EDdA.classification'], + scripts=['scripts/modelConfusionMatrix.py']) -- GitLab