From 4fce5770861083a23bef05a1ab9a161261806fff Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Wed, 20 Sep 2023 23:27:30 +0200
Subject: [PATCH] Put format autodetector with the classes themselves to allow
 reuse

---
 scripts/ML/Corpus.py         | 13 +++++++++++--
 scripts/ML/convert-corpus.py | 13 +++----------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py
index b01a7c8..d2ea16d 100644
--- a/scripts/ML/Corpus.py
+++ b/scripts/ML/Corpus.py
@@ -1,5 +1,6 @@
 import pandas
-import os
+from os import makedirs
+from os.path import dirname, isdir, isfile
 
 def abstract(f):
     def wrapped(*args, **kwargs):
@@ -128,7 +129,7 @@ class Directory(TSVIndexed):
 
     def write_text(self, primary_key, content):
         path = self.path_to(primary_key)
-        os.makedirs(os.path.dirname(path), exist_ok=True)
+        makedirs(dirname(path), exist_ok=True)
         with open(path, 'w') as file:
             file.write(content)
 
@@ -138,3 +139,11 @@ class Directory(TSVIndexed):
         for _, row in self.data.iterrows():
             self.write_text(row, row[self.column_name])
         self.data[self.keys].to_csv(self.tsv_path, sep='\t', index=False)
+
+def corpus(path):
+    if path[-1:] == '/':
+        return Directory(path)
+    elif path[-4:] == '.tsv':
+        return SelfContained(path)
+    else:
+        raise FileNotFoundError(path)
diff --git a/scripts/ML/convert-corpus.py b/scripts/ML/convert-corpus.py
index 98135d4..a37fb2c 100755
--- a/scripts/ML/convert-corpus.py
+++ b/scripts/ML/convert-corpus.py
@@ -1,15 +1,8 @@
 #!/usr/bin/env python3
-import Corpus
-from os.path import isdir
+from Corpus import corpus
 import sys
 
-def detect(path):
-    if isdir(path):
-        return Corpus.Directory(path)
-    else:
-        return Corpus.SelfContained(path)
-
 if __name__ == '__main__':
-    source = detect(sys.argv[1])
-    destination = detect(sys.argv[2])
+    source = corpus(sys.argv[1])
+    destination = corpus(sys.argv[2])
     destination.save(source.get_all())
-- 
GitLab