From 0d65742640d42d1ee7e380c044d511f63b91b05c Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Thu, 21 Sep 2023 23:36:27 +0200
Subject: [PATCH] Fix mistakes introduced while refactoring

---
 scripts/ML/BERT/Base.py    |  8 ++++----
 scripts/ML/BERT/Trainer.py |  4 ++++
 scripts/ML/Corpus.py       | 12 +++++++-----
 scripts/ML/predict.py      |  4 ++--
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py
index c8b8d11..5c2de16 100644
--- a/scripts/ML/BERT/Base.py
+++ b/scripts/ML/BERT/Base.py
@@ -42,15 +42,15 @@ class BERT:
         print('Loading BERT tools')
         self._init_tokenizer()
         self.root_path = root_path
-        _init_classifier(training)
+        self._init_classifier(training)
 
     @loader
-    def _init_tokenizer():
+    def _init_tokenizer(self):
         self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
 
     @loader
-    def _init_classifier(training)
-        if training
+    def _init_classifier(self, training):
+        if training:
             bert = BertForSequenceClassification.from_pretrained(
                     model_name, # Use the 12-layer BERT model, with an uncased vocab.
                     num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
diff --git a/scripts/ML/BERT/Trainer.py b/scripts/ML/BERT/Trainer.py
index e69de29..60bd0e5 100644
--- a/scripts/ML/BERT/Trainer.py
+++ b/scripts/ML/BERT/Trainer.py
@@ -0,0 +1,4 @@
+from BERT.Base import BERT
+
+class Trainer(BERT):
+    pass
diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py
index b81a56c..910cf33 100644
--- a/scripts/ML/Corpus.py
+++ b/scripts/ML/Corpus.py
@@ -30,6 +30,8 @@ class TSVIndexed(Corpus):
         self.tsv_path = tsv_path
         self.column_name = column_name
         self.data = None
+        self.projectors = dict((p, self.__getattribute__(p))
+                              for p in ['key', 'content', 'full'])
 
     def load(self):
         if self.data is None:
@@ -46,19 +48,19 @@ class TSVIndexed(Corpus):
     def content(self, key, row):
         pass
 
-    def keys(self, _, row):
+    def key(self, _, row):
         return row[self.keys].to_dict()
 
     def full(self, key, row):
-        d = self.keys(key, row)
+        d = self.key(key, row)
         d[self.column_name] = self.content(key, row).strip() + '\n'
         return d
 
-    def get_all(self, projector):
+    def get_all(self, projector=None):
         if projector is None:
             projector = self.full
-        elif type(projector) == str:
-            projector = self.__getattribute__(projector)
+        elif type(projector) == str and projector in self.projectors:
+            projector = self.projectors[projector]
         self.load()
         for row in self.data.iterrows():
             yield projector(*row)
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index f1768db..2dfa3e8 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -21,8 +21,8 @@ def label(classify, source, name='label'):
     :return: a panda dataframe containing the records from the input TSV file plus
     an additional column
     """
-    records = pandas.DataFrame(source.get_all('keys'))
-    records[name] = classify(source.get_all('content')
+    records = pandas.DataFrame(source.get_all('key'))
+    records[name] = classify(source.get_all('content'))
     return records
 
 if __name__ == '__main__':
-- 
GitLab