From 4b1f435b948e276addb5e202a2d071fa10c33fe3 Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Mon, 6 Nov 2023 10:54:24 +0100
Subject: [PATCH] Add script to generate confusion matrices

---
 scripts/modelConfusionMatrix.py | 21 +++++++++++++++++++++
 setup.py                        |  3 ++-
 2 files changed, 23 insertions(+), 1 deletion(-)
 create mode 100755 scripts/modelConfusionMatrix.py

diff --git a/scripts/modelConfusionMatrix.py b/scripts/modelConfusionMatrix.py
new file mode 100755
index 0000000..7bd3eb3
--- /dev/null
+++ b/scripts/modelConfusionMatrix.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+
+from EDdA.store import preparePath
+import pandas
+from sklearn.metrics import confusion_matrix
+from EDdA.classification import heatmap
+import sys
+
+def modelConfusionMatrix():
+    if len(sys.argv) == 3:
+        input_csv = sys.argv[1]
+        output_png = sys.argv[2]
+        model = pandas.read_csv(input_csv, index_col=0)
+        labels, predictions = model['labels'], model['predictions']
+        confusion = confusion_matrix(labels, predictions, normalize='true')
+        heatmap(confusion, preparePath(output_png))
+    else:
+        print(f"Syntax: {sys.argv[0]} INPUT_CSV OUTPUT_PNG")
+
+if __name__ == '__main__':
+    modelConfusionMatrix()
diff --git a/setup.py b/setup.py
index 1d445ec..7ca1db9 100644
--- a/setup.py
+++ b/setup.py
@@ -4,4 +4,5 @@ from setuptools import setup
 
 setup(name='PyEDdA',
       version='0.1',
-      packages=['EDdA', 'EDdA.classification'])
+      packages=['EDdA', 'EDdA.classification'],
+      scripts=['scripts/modelConfusionMatrix.py'])
-- 
GitLab