diff --git a/scripts/ML/BERT.py b/scripts/ML/BERT.py
new file mode 100644
index 0000000000000000000000000000000000000000..2159179af3ab2c063555712f6b171ab37446ed2a
--- /dev/null
+++ b/scripts/ML/BERT.py
@@ -0,0 +1,13 @@
+from loaders import get_device
+from transformers import BertForSequenceClassification, BertTokenizer
+
+class BERT:
+    model_name = 'bert-base-multilingual-cased'
+    def __init__(self, path):
+        print('Loading BERT tools…')
+        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
+        print('✔️ tokenizer')
+        self.device = get_device()
+        bert = BertForSequenceClassification.from_pretrained(path)
+        self.model = bert.to(self.device.type)
+        print('✔️ classifier')
diff --git a/scripts/ML/loaders.py b/scripts/ML/loaders.py
index dc05ef7e5b1ce7c93bc72d70dfb9137f25009e75..859669d42884c62c4b0f62f77e5bf852eb4829e7 100644
--- a/scripts/ML/loaders.py
+++ b/scripts/ML/loaders.py
@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None):
     else:
         raise FileNotFoundError(path)
 
-def get_tokenizer():
-    model_name = 'bert-base-multilingual-cased'
-    print('Loading BERT tokenizer...')
-    return BertTokenizer.from_pretrained(model_name)
+def set_random():
+    seed_val = 42
+    random.seed(seed_val)
+    np.random.seed(seed_val)
+    torch.manual_seed(seed_val)
+    torch.cuda.manual_seed_all(seed_val)
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 5ac70b06b1a482a058a9c6286a2fa1264659fd70..f6bdba4d2f62fb4c96ddef42b9fe8b709da4a6b9 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-import loaders import get_device, get_encoder, get_tokenizer
+from BERT import BERT
+from loaders import get_encoder
 import numpy
 import pandas
 import sklearn
@@ -8,10 +9,10 @@ from sys import argv
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
-class Classifier:
+class Classifier(BERT):
     """
     A class wrapping all the different models and classes used throughout a
-    classification task:
+    classification task and based on BERT:
 
         - tokenizer
         - classifier
@@ -22,16 +23,10 @@ class Classifier:
         containing the texts to classify
     """
     def __init__(self, root_path):
-        self.device = get_device()
-        self.tokenizer = get_tokenizer()
-        self._init_model(root_path)
+        BERT.__init__(self, root_path)
         self._init_pipe()
         self.encoder = get_encoder(root_path)
 
-    def _init_model(self, path):
-        bert = BertForSequenceClassification.from_pretrained(path)
-        self.model = bert.to(self.device.type)
-
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
             model=self.model,