From 774ad89ba7c7ff198f5268d42299afd3e95d66ff Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Tue, 19 Sep 2023 12:12:07 +0200
Subject: [PATCH] Factorize BERT components into a subclass

---
 scripts/ML/BERT.py    | 13 +++++++++++++
 scripts/ML/loaders.py | 10 ++++++----
 scripts/ML/predict.py | 15 +++++----------
 3 files changed, 24 insertions(+), 14 deletions(-)
 create mode 100644 scripts/ML/BERT.py

diff --git a/scripts/ML/BERT.py b/scripts/ML/BERT.py
new file mode 100644
index 0000000..2159179
--- /dev/null
+++ b/scripts/ML/BERT.py
@@ -0,0 +1,13 @@
+from transformers import BertForSequenceClassification, BertTokenizer
+from loaders import get_device
+
+class BERT:
+    model_name = 'bert-base-multilingual-cased'
+    def __init__(self, path):
+        print('Loading BERT tools…')
+        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
+        print('✔️ tokenizer')
+        self.device = get_device()
+        bert = BertForSequenceClassification.from_pretrained(path)
+        self.model = bert.to(self.device.type)
+        print('✔️ classifier')
diff --git a/scripts/ML/loaders.py b/scripts/ML/loaders.py
index dc05ef7..859669d 100644
--- a/scripts/ML/loaders.py
+++ b/scripts/ML/loaders.py
@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None):
     else:
         raise FileNotFoundError(path)
 
-def get_tokenizer():
-    model_name = 'bert-base-multilingual-cased'
-    print('Loading BERT tokenizer...')
-    return BertTokenizer.from_pretrained(model_name)
+def set_random():
+    seed_val = 42
+    random.seed(seed_val)
+    np.random.seed(seed_val)
+    torch.manual_seed(seed_val)
+    torch.cuda.manual_seed_all(seed_val)
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 5ac70b0..f6bdba4 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-import loaders import get_device, get_encoder, get_tokenizer
+from BERT import BERT
+from loaders import get_encoder
 import numpy
 import pandas
 import sklearn
@@ -8,10 +9,10 @@ from sys import argv
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
 
-class Classifier:
+class Classifier(BERT):
     """ A class wrapping all the different models and classes used throughout a
-        classification task:
+        classification task and based on BERT:
 
             - tokenizer
            - classifier
@@ -22,16 +23,10 @@ class Classifier:
            containing the texts to classify
     """
     def __init__(self, root_path):
-        self.device = get_device()
-        self.tokenizer = get_tokenizer()
-        self._init_model(root_path)
+        BERT.__init__(self, root_path)
         self._init_pipe()
         self.encoder = get_encoder(root_path)
 
-    def _init_model(self, path):
-        bert = BertForSequenceClassification.from_pretrained(path)
-        self.model = bert.to(self.device.type)
-
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
             model=self.model,
-- 
GitLab
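
Usage note (not part of the patch): a minimal sketch of how the refactored Classifier is constructed after this change. It assumes scripts/ML is on the Python path and that the directory name 'path/to/model_dir' is hypothetical; it stands for a folder holding the fine-tuned BertForSequenceClassification weights and the label encoder that get_encoder expects.

    from predict import Classifier

    # Hypothetical model directory; the name is only an assumption for this example.
    classifier = Classifier('path/to/model_dir')

    # After __init__, the components set up by BERT.__init__ and _init_pipe are available:
    print(classifier.tokenizer)  # BertTokenizer for 'bert-base-multilingual-cased'
    print(classifier.model)      # BertForSequenceClassification moved to the detected device
    print(classifier.pipe)       # TextClassificationPipeline built around the model
    print(classifier.encoder)    # label encoder loaded by get_encoder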