From cbb53ca39d1d32831e6471718663ae54df821b54 Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Wed, 22 Nov 2023 17:35:20 +0100
Subject: [PATCH] Add a script to retrieve a Simple train set from a Multi one

---
 scripts/ML/Corpus.py             |  5 +++--
 scripts/ML/simpleTrainOfMulti.py | 26 ++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100755 scripts/ML/simpleTrainOfMulti.py

diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py
index 00d5d9f..5abf5b3 100644
--- a/scripts/ML/Corpus.py
+++ b/scripts/ML/Corpus.py
@@ -57,14 +57,15 @@ class TSVIndexed(Corpus):
         d[self.column_name] = self.content(key, row).strip() + '\n'
         return d
 
-    def get_all(self, projector=None):
+    def get_all(self, projector=None, where=None):  # where: optional row filter
         if projector is None:
             projector = self.full
         elif type(projector) == str and projector in self.projectors:
             projector = self.__getattribute__(projector)
         self.load()
         for row in self.data.iterrows():
-            yield projector(*row)
+            if where is None or where(*row):  # called like projector, with *row
+                yield projector(*row)
 
 class SelfContained(TSVIndexed):
     """
diff --git a/scripts/ML/simpleTrainOfMulti.py b/scripts/ML/simpleTrainOfMulti.py
new file mode 100755
index 0000000..5f80001
--- /dev/null
+++ b/scripts/ML/simpleTrainOfMulti.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+
+from Corpus import Directory, SelfContained
+from GEODE import fromKey, toKey
+import GEODE.discursive as discursive
+from prodigyAcceptedJSONLToTSV import acceptedToTSV
+from sys import argv
+
+def isAccepted(key, row):  # row predicate; key is received but not used
+    return row['answer'] == 'accept'  # True only for rows answered 'accept'
+
+def withLabel(corpus, label):  # builds a (key, row) projector for get_all
+    return lambda key, row: dict(**corpus.full(key, row)  # new dict per row
+                                 , paragraphFunction=label)  # NOTE(review): TypeError if full() already has this key
+
+def simpleTrainOfMulti(multiDirectory, outputTSV):
+    """Merge the accepted rows of each discursive function into one TSV."""
+    rows = []
+    for function in discursive.functions:
+        source = Directory(multiDirectory, tsv_filename=function)
+        labelled = withLabel(source, function)
+        rows.extend(source.get_all(projector=labelled, where=isAccepted))
+    SelfContained(outputTSV).save(sorted(rows, key=toKey))
+
+if __name__ == '__main__':
+    simpleTrainOfMulti(argv[1], argv[2])  # args: multiDirectory, outputTSV
-- 
GitLab