From 774ad89ba7c7ff198f5268d42299afd3e95d66ff Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Tue, 19 Sep 2023 12:12:07 +0200
Subject: [PATCH] Factorize BERT components into a subclass

---
 scripts/ML/BERT.py    | 12 ++++++++++++
 scripts/ML/loaders.py | 10 ++++++----
 scripts/ML/predict.py | 15 +++++----------
 3 files changed, 23 insertions(+), 14 deletions(-)
 create mode 100644 scripts/ML/BERT.py

diff --git a/scripts/ML/BERT.py b/scripts/ML/BERT.py
new file mode 100644
index 0000000..2159179
--- /dev/null
+++ b/scripts/ML/BERT.py
@@ -0,0 +1,13 @@
+from transformers import BertForSequenceClassification, BertTokenizer
+from loaders import get_device
+
+class BERT:
+    model_name = 'bert-base-multilingual-cased'
+    def __init__(self, path):
+        print('Loading BERT tools…')
+        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
+        print('✔️ tokenizer')
+        self.device = get_device()
+        bert = BertForSequenceClassification.from_pretrained(path)
+        self.model = bert.to(self.device.type)
+        print('✔️ classifier')
diff --git a/scripts/ML/loaders.py b/scripts/ML/loaders.py
index dc05ef7..859669d 100644
--- a/scripts/ML/loaders.py
+++ b/scripts/ML/loaders.py
@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None):
     else:
         raise FileNotFoundError(path)
 
-def get_tokenizer():
-    model_name = 'bert-base-multilingual-cased'
-    print('Loading BERT tokenizer...')
-    return BertTokenizer.from_pretrained(model_name)
+def set_random():
+    seed_val = 42
+    random.seed(seed_val)
+    np.random.seed(seed_val)
+    torch.manual_seed(seed_val)
+    torch.cuda.manual_seed_all(seed_val)
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 5ac70b0..f6bdba4 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-import loaders import get_device, get_encoder, get_tokenizer
+from BERT import BERT
+from loaders import get_encoder
 import numpy
 import pandas
 import sklearn
@@ -8,10 +9,10 @@ from sys import argv
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
-class Classifier:
+class Classifier(BERT):
     """
     A class wrapping all the different models and classes used throughout a
-    classification task:
+    classification task and based on BERT:
 
         - tokenizer
         - classifier
@@ -22,16 +23,10 @@ class Classifier:
     containing the texts to classify
     """
     def __init__(self, root_path):
-        self.device = get_device()
-        self.tokenizer = get_tokenizer()
-        self._init_model(root_path)
+        BERT.__init__(self, root_path)
         self._init_pipe()
         self.encoder = get_encoder(root_path)
 
-    def _init_model(self, path):
-        bert = BertForSequenceClassification.from_pretrained(path)
-        self.model = bert.to(self.device.type)
-
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
             model=self.model,
-- 
GitLab