From ef245f2984af6747059962d836b452d5f754fd91 Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Fri, 15 Sep 2023 19:13:15 +0200
Subject: [PATCH] Keep reworking things, factorize source directory handling

---
 scripts/ML/Source.py  | 29 +++++++++++++++++++++++
 scripts/ML/gpu.py     | 10 --------
 scripts/ML/loaders.py | 31 +++++++++++++++++++++++++
 scripts/ML/predict.py | 53 +++++++------------------------------------
 4 files changed, 68 insertions(+), 55 deletions(-)
 create mode 100644 scripts/ML/Source.py
 delete mode 100644 scripts/ML/gpu.py
 create mode 100644 scripts/ML/loaders.py

diff --git a/scripts/ML/Source.py b/scripts/ML/Source.py
new file mode 100644
index 0000000..8800760
--- /dev/null
+++ b/scripts/ML/Source.py
@@ -0,0 +1,29 @@
+class Source:
+    """
+    A class to handle the normalised path used in the project and loading the
+    actual text input as a generator from records when they are needed
+    """
+    def __init__(self, root_path):
+        """
+        Positional arguments
+        :param root_path: the path to a GÉODE-style folder containing the text
+        version of the corpus on which to predict the classes
+        """
+        self.root_path = root_path
+
+    def path_to(self, record):
+        article_relative_path = "{work}/T{volume}/{article}".format(**record)
+        prefix = f"{self.root_path}/{article_relative_path}"
+        if 'paragraph' in record:
+            return f"{prefix}/{record.paragraph}.txt"
+        else:
+            return f"{prefix}.txt"
+
+    def load_text(self, record):
+        with open(self.path_to(record), 'r') as file:
+            return file.read()
+
+    def iterate(self, records):
+        for _, record in records.iterrows():
+            yield self.load_text(record)
+
diff --git a/scripts/ML/gpu.py b/scripts/ML/gpu.py
deleted file mode 100644
index b0b34b2..0000000
--- a/scripts/ML/gpu.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import torch
-
-class WithGPU:
-    def __init__(self):
-        if torch.cuda.is_available():
-            print('We will use the GPU:', torch.cuda.get_device_name(0))
-            self.device = torch.device("cuda")
-        else:
-            print('No GPU available, using the CPU instead.')
-            self.device = torch.device("cpu")
diff --git a/scripts/ML/loaders.py b/scripts/ML/loaders.py
new file mode 100644
index 0000000..dc05ef7
--- /dev/null
+++ b/scripts/ML/loaders.py
@@ -0,0 +1,31 @@
+import os
+import pickle
+from sklearn import preprocessing
+import torch
+
+def get_device():
+    if torch.cuda.is_available():
+        print('We will use the GPU:', torch.cuda.get_device_name(0))
+        return torch.device("cuda")
+    else:
+        print('No GPU available, using the CPU instead.')
+        return torch.device("cpu")
+
+def get_encoder(root_path, create_from=None):
+    path = f"{root_path}/label_encoder.pkl"
+    if os.path.isfile(path):
+        with open(path, 'rb') as pickled:
+            return pickle.load(pickled)
+    elif create_from is not None:
+        encoder = preprocessing.LabelEncoder()
+        encoder.fit(create_from)
+        with open(path, 'wb') as file:
+            pickle.dump(encoder, file)
+        return encoder
+    else:
+        raise FileNotFoundError(path)
+
+def get_tokenizer():
+    model_name = 'bert-base-multilingual-cased'
+    print('Loading BERT tokenizer...')
+    return BertTokenizer.from_pretrained(model_name)
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 0f00ab0..5ac70b0 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,14 +1,14 @@
 #!/usr/bin/env python3
-from gpu import WithGPU
+import loaders import get_device, get_encoder, get_tokenizer
 import numpy
 import pandas
-import pickle
 import sklearn
+from Source import Source
 from sys import argv
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
-class Classifier(WithGPU):
+class Classifier:
     """
     A class wrapping all the different models and classes used throughout a
     classification task:
@@ -22,20 +22,16 @@ class Classifier(WithGPU):
     containing the texts to classify
     """
     def __init__(self, root_path):
-        WithGPU.__init__(self)
-        self._init_tokenizer()
+        self.device = get_device()
+        self.tokenizer = get_tokenizer()
         self._init_model(root_path)
         self._init_pipe()
-        self._init_encoder(f"{root_path}/label_encoder.pkl")
+        self.encoder = get_encoder(root_path)
 
     def _init_model(self, path):
         bert = BertForSequenceClassification.from_pretrained(path)
         self.model = bert.to(self.device.type)
 
-    def _init_tokenizer(self):
-        model_name = 'bert-base-multilingual-cased'
-        self.tokenizer = BertTokenizer.from_pretrained(model_name)
-
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
             model=self.model,
@@ -43,10 +39,6 @@ class Classifier(WithGPU):
             return_all_scores=True,
             device=self.device)
 
-    def _init_encoder(self, path):
-        with open(path, 'rb') as pickled:
-            self.encoder = pickle.load(pickled)
-
     def __call__(self, text_generator):
         tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512}
         predictions = []
@@ -55,37 +47,8 @@ class Classifier(WithGPU):
             predictions.append([int(byScoreDesc[0]['label'][6:]),
                                 byScoreDesc[0]['score'],
                                 int(byScoreDesc[1]['label'][6:])])
-        predictions = numpy.array(predictions)
-        return list(self.encoder.inverse_transform(predictions[:,0].astype(int)))
-
-class Source:
-    """
-    A class to handle the normalised path used in the project and loading the
-    actual text input as a generator from records when they are needed
-    """
-    def __init__(self, root_path):
-        """
-        Positional arguments
-        :param root_path: the path to a GÉODE-style folder containing the text
-        version of the corpus on which to predict the classes
-        """
-        self.root_path = root_path
-
-    def path_to(self, record):
-        article_relative_path = "{work}/T{volume}/{article}".format(**record)
-        prefix = f"{self.root_path}/{article_relative_path}"
-        if 'paragraph' in record:
-            return f"{prefix}/{record.paragraph}.txt"
-        else:
-            return f"{prefix}.txt"
-
-    def load_text(self, record):
-        with open(self.path_to(record), 'r') as file:
-            return file.read()
-
-    def iterate(self, records):
-        for _, record in records.iterrows():
-            yield self.load_text(record)
+        return self.encoder.inverse_transform(
+                numpy.array(predictions)[:,0].astype(int))
 
 def label(classify, source, tsv_path, name='label'):
     """
-- 
GitLab