Skip to content
Snippets Groups Projects
Commit 774ad89b authored by Alice Brenon's avatar Alice Brenon
Browse files

Factorize BERT components into a subclass

parent ef245f29
No related branches found
No related tags found
No related merge requests found
from loaders import get_device
class BERT:
model_name = 'bert-base-multilingual-cased'
def __init(self, path):
print('Loading BERT tools…')
self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
print('✔️ tokenizer')
self.device = get_device()
bert = BertForSequenceClassification.from_pretrained(path)
self.model = bert.to(self.device.type)
print('✔️ classifier')
...@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None): ...@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None):
else: else:
raise FileNotFoundError(path) raise FileNotFoundError(path)
def get_tokenizer(): def set_random():
model_name = 'bert-base-multilingual-cased' seed_value = 42
print('Loading BERT tokenizer...') random.seed(seed_val)
return BertTokenizer.from_pretrained(model_name) np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
#!/usr/bin/env python3 #!/usr/bin/env python3
import loaders import get_device, get_encoder, get_tokenizer from BERT import BERT
from loaders import get_encoder
import numpy import numpy
import pandas import pandas
import sklearn import sklearn
...@@ -8,10 +9,10 @@ from sys import argv ...@@ -8,10 +9,10 @@ from sys import argv
from tqdm import tqdm from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
class Classifier: class Classifier(BERT):
""" """
A class wrapping all the different models and classes used throughout a A class wrapping all the different models and classes used throughout a
classification task: classification task and based on BERT:
- tokenizer - tokenizer
- classifier - classifier
...@@ -22,16 +23,10 @@ class Classifier: ...@@ -22,16 +23,10 @@ class Classifier:
containing the texts to classify containing the texts to classify
""" """
def __init__(self, root_path): def __init__(self, root_path):
self.device = get_device() BERT.__init__(self, root_path)
self.tokenizer = get_tokenizer()
self._init_model(root_path)
self._init_pipe() self._init_pipe()
self.encoder = get_encoder(root_path) self.encoder = get_encoder(root_path)
def _init_model(self, path):
bert = BertForSequenceClassification.from_pretrained(path)
self.model = bert.to(self.device.type)
def _init_pipe(self): def _init_pipe(self):
self.pipe = TextClassificationPipeline( self.pipe = TextClassificationPipeline(
model=self.model, model=self.model,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment