from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline


class Classifier(BERT):
    """Callable text classifier built on top of a BERT model.

    Wraps the different models and objects used throughout a classification
    task based on BERT:

    - tokenizer
    - classifier model
    - Hugging Face ``TextClassificationPipeline``
    - label encoder

    Once created, the instance behaves like a function: call it with an
    iterable/generator of texts to obtain the predicted labels and their
    scores.
    """

    # Prefix of the default label names Hugging Face generates
    # ("LABEL_0", "LABEL_1", ...) when the model config has no custom names.
    _DEFAULT_LABEL_PREFIX = 'LABEL_'

    def __init__(self, root_path):
        """Load the BERT components from *root_path* and build the pipeline.

        NOTE(review): ``self.model``, ``self.tokenizer``, ``self.device`` and
        ``self.encoder`` are presumably set by ``BERT.__init__`` (not visible
        here) -- TODO confirm.
        """
        super().__init__(root_path)
        self._init_pipe()

    def _init_pipe(self):
        """Create the text-classification pipeline keeping only the best class."""
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=1,
            device=self.device,
        )

    @staticmethod
    def _best_prediction(output):
        """Return the highest-scoring ``{'label': ..., 'score': ...}`` entry.

        Depending on the transformers version and arguments, the pipeline may
        yield either a single dict or a list of dicts per input text; both
        shapes are accepted. On ties the first entry wins, exactly like taking
        the head of a stable descending sort.
        """
        if isinstance(output, dict):
            return output
        return max(output, key=lambda prediction: prediction['score'])

    def _label_to_index(self, label):
        """Map a pipeline label name back to its integer class index.

        Args:
            label: label string produced by the pipeline.

        Returns:
            The integer class index expected by ``self.encoder``.

        Raises:
            ValueError: if *label* cannot be mapped to an index.
        """
        prefix = self._DEFAULT_LABEL_PREFIX
        if label.startswith(prefix):
            return int(label[len(prefix):])
        # Custom label names: look them up in the config of the model that
        # produced them instead of slicing a fixed number of characters off.
        try:
            return int(self.pipe.model.config.label2id[label])
        except (AttributeError, KeyError, TypeError) as err:
            raise ValueError(
                f'Cannot map pipeline label {label!r} to a class index'
            ) from err

    def __call__(self, text_generator, max_length=512):
        """Classify every text yielded by *text_generator*.

        A tqdm progress bar is shown while the texts are processed.

        Args:
            text_generator: iterable/generator of texts to classify.
            max_length: maximum number of tokens per text; longer texts are
                truncated (defaults to 512, the previous hard-coded value).

        Returns:
            A ``(labels, scores)`` tuple: the predicted class indices decoded
            through ``self.encoder.inverse_transform`` (presumably an sklearn
            ``LabelEncoder`` -- TODO confirm), and a numpy array holding the
            score of each predicted class, in input order.
        """
        tokenizer_kwargs = {
            'padding': True,
            'truncation': True,
            'max_length': max_length,
        }
        labels, scores = [], []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            best = self._best_prediction(output)
            labels.append(self._label_to_index(best['label']))
            scores.append(best['score'])
        return self.encoder.inverse_transform(labels), numpy.array(scores)