From 7ccb5f7eda682ee38382cece9f277632dc857e14 Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Wed, 22 Nov 2023 17:31:11 +0100
Subject: [PATCH] Add a script to generate reports and draw a confusion matrix

---
 scripts/ML/evaluate.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100755 scripts/ML/evaluate.py

diff --git a/scripts/ML/evaluate.py b/scripts/ML/evaluate.py
new file mode 100755
index 0000000..104cdad
--- /dev/null
+++ b/scripts/ML/evaluate.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+from EDdA.classification import heatmap
+from EDdA.store import preparePath
+import GEODE.discursive as discursive
+import pandas
+from sklearn.metrics import classification_report, confusion_matrix
+from sys import argv
+
+def evaluate(truth, predictions, outputDirectory):
+    # Confusion matrix over the full label set, row-normalized ('true').
+    matrix = confusion_matrix(truth, predictions,
+                              labels=list(discursive.functions),
+                              normalize='true')
+    heatmap(matrix, preparePath(f"{outputDirectory}/confusion.png"),
+            labels=discursive.functions)
+    import json  # local import keeps the file's top-level import block intact
+    # json.dump, not print: a dict's repr uses single quotes and is not valid JSON.
+    with open(f"{outputDirectory}/report.json", 'w') as jsonFile:
+        json.dump(classification_report(truth, predictions, output_dict=True),
+                  jsonFile)
+    with open(f"{outputDirectory}/report.txt", 'w') as txt:
+        print(classification_report(truth, predictions), file=txt)
+
+if __name__ == '__main__':
+    # CLI: evaluate.py <truth.tsv> <predictions.tsv> <output-directory>
+    frames = [pandas.read_csv(path, sep='\t') for path in argv[1:3]]
+    evaluate(frames[0]['paragraphFunction'], frames[1]['label'], argv[3])
-- 
GitLab