Skip to content
Snippets Groups Projects
Commit 774ad89b authored by Alice Brenon's avatar Alice Brenon
Browse files

Factorize BERT components into a subclass

parent ef245f29
No related branches found
No related tags found
No related merge requests found
from loaders import get_device
class BERT:
model_name = 'bert-base-multilingual-cased'
def __init(self, path):
print('Loading BERT tools…')
self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
print('✔️ tokenizer')
self.device = get_device()
bert = BertForSequenceClassification.from_pretrained(path)
self.model = bert.to(self.device.type)
print('✔️ classifier')
......@@ -25,7 +25,9 @@ def get_encoder(root_path, create_from=None):
else:
raise FileNotFoundError(path)
def get_tokenizer():
model_name = 'bert-base-multilingual-cased'
print('Loading BERT tokenizer...')
return BertTokenizer.from_pretrained(model_name)
def set_random():
seed_value = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
#!/usr/bin/env python3
import loaders import get_device, get_encoder, get_tokenizer
from BERT import BERT
from loaders import get_encoder
import numpy
import pandas
import sklearn
......@@ -8,10 +9,10 @@ from sys import argv
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
class Classifier:
class Classifier(BERT):
"""
A class wrapping all the different models and classes used throughout a
classification task:
classification task and based on BERT:
- tokenizer
- classifier
......@@ -22,16 +23,10 @@ class Classifier:
containing the texts to classify
"""
def __init__(self, root_path):
self.device = get_device()
self.tokenizer = get_tokenizer()
self._init_model(root_path)
BERT.__init__(self, root_path)
self._init_pipe()
self.encoder = get_encoder(root_path)
def _init_model(self, path):
bert = BertForSequenceClassification.from_pretrained(path)
self.model = bert.to(self.device.type)
def _init_pipe(self):
self.pipe = TextClassificationPipeline(
model=self.model,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment