from BERT.Base import BERT

import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline


class Classifier(BERT):
    """
    A class wrapping all the different models and classes used throughout
    a classification task based on BERT:
        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    yielding the texts to classify.
    """

    def __init__(self, root_path):
        BERT.__init__(self, root_path)
        self._init_pipe()

    def _init_pipe(self):
        # Build a text-classification pipeline from the model, tokenizer and
        # device loaded by the BERT base class; return_all_scores=True makes
        # the pipeline return the score of every label for each input.
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
            device=self.device)

    def __call__(self, text_generator):
        # Pad/truncate every input to BERT's 512-token limit.
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        predictions = []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # Each output is a list of {'label': 'LABEL_<id>', 'score': float}
            # dicts; sort it so the most probable label comes first.
            byScoreDesc = sorted(output, key=lambda d: d['score'], reverse=True)
            # Keep the numeric ids of the top-2 labels ([6:] strips the
            # 'LABEL_' prefix) together with the top score.
            predictions.append([int(byScoreDesc[0]['label'][6:]),
                                byScoreDesc[0]['score'],
                                int(byScoreDesc[1]['label'][6:])])
        # Map the numeric ids of the best predictions back to the original
        # class names through the label encoder.
        return self.encoder.inverse_transform(
            numpy.array(predictions)[:, 0].astype(int))
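

# Minimal usage sketch, assuming the BERT base class loads a fine-tuned
# checkpoint, tokenizer and fitted label encoder from the directory passed as
# `root_path`; the path and example texts below are hypothetical.
if __name__ == '__main__':
    classifier = Classifier('models/my_classifier')  # hypothetical checkpoint directory
    texts = ['first document to classify', 'second document to classify']
    # The classifier can be applied to any iterable of texts, e.g. a generator.
    labels = classifier(iter(texts))
    print(list(labels))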