from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline

class Classifier(BERT):
    """
    A class wrapping all the models and components used throughout a
    BERT-based classification task:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, an instance behaves as a function which can be applied to
    a generator yielding the texts to classify (see the usage sketch at the
    end of this module).
    """
    def __init__(self, root_path):
        BERT.__init__(self, root_path)
        self._init_pipe()

    def _init_pipe(self):
        # Build a text-classification pipeline reusing the model, tokenizer
        # and device initialised by the BERT base class
        self.pipe = TextClassificationPipeline(
                model=self.model,
                tokenizer=self.tokenizer,
                top_k=1,
                device=self.device)
    def __call__(self, text_generator):
        # Keyword arguments forwarded to the tokenizer by the pipeline
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        labels, scores = [], []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # Each output is a list of {'label', 'score'} dicts; keep the best
            byScoreDesc = sorted(output, key=lambda d: d['score'], reverse=True)
            # Labels come back as 'LABEL_<n>' strings: strip the 6-character
            # prefix to recover the integer index expected by the label encoder
            labels.append(int(byScoreDesc[0]['label'][6:]))
            scores.append(byScoreDesc[0]['score'])
        return self.encoder.inverse_transform(labels), numpy.array(scores)
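
# Minimal usage sketch, assuming the BERT base class loads its model,
# tokenizer, device and label encoder from `root_path`. The model directory
# and input file below are hypothetical examples, not part of this module.
if __name__ == '__main__':
    classifier = Classifier('models/my-finetuned-bert')  # hypothetical path
    with open('corpus.txt') as corpus:  # hypothetical file, one text per line
        texts = (line.strip() for line in corpus)
        labels, scores = classifier(texts)
    for label, score in zip(labels, scores):
        print(f"{label}\t{score:.3f}")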