Skip to content
Snippets Groups Projects
Classifier.py 1.35 KiB
Newer Older
Alice Brenon's avatar
Alice Brenon committed
from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline

class Classifier(BERT):
    """
    A class wrapping all the different models and classes used throughout a
    classification task and based on BERT:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    containing the texts to classify.

    The attributes `model`, `tokenizer`, `device` and `encoder` used below are
    not set in this class; they are presumably provided by the `BERT` base
    class — verify against BERT.Base.
    """

    # Prefix stripped from the labels emitted by the pipeline in order to
    # recover the numeric class index (labels are expected to look like
    # 'LABEL_<n>').
    _LABEL_PREFIX = 'LABEL_'
    # Upper bound on the tokenized length; longer inputs get truncated.
    _MAX_LENGTH = 512

    def __init__(self, root_path):
        """
        Initialise the BERT base from `root_path`, then build the
        classification pipeline on top of it.
        """
        super().__init__(root_path)
        self._init_pipe()

    def _init_pipe(self):
        """
        Create the text classification pipeline and store it in `self.pipe`.
        """
        # NOTE(review): `return_all_scores` is reported as deprecated in recent
        # transformers releases in favour of `top_k=None` — confirm against the
        # pinned version before switching.
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
            device=self.device)

    def __call__(self, text_generator):
        """
        Classify every text produced by `text_generator`.

        For each text the pipeline yields one score per label; only the label
        with the highest score is kept. Its numeric index is parsed from the
        label name and the indices are finally mapped back through
        `self.encoder.inverse_transform`, whose result is returned as is.

        An empty `text_generator` results in an empty integer array being
        handed to the encoder instead of raising an IndexError.

        Raises ValueError if a label's suffix is not an integer.
        """
        tokenizer_kwargs = {
            'padding': True,
            'truncation': True,
            'max_length': self._MAX_LENGTH,
        }
        prefix_length = len(self._LABEL_PREFIX)
        top_labels = []
        for scores in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # max() returns the first best-scoring entry, which is the same
            # entry a stable descending sort would put first, and it does not
            # require the model to expose at least two labels.
            best = max(scores, key=lambda d: d['score'])
            top_labels.append(int(best['label'][prefix_length:]))
        # An explicit dtype keeps the array 1-D and integer-typed even when
        # there was nothing to classify.
        return self.encoder.inverse_transform(
                numpy.array(top_labels, dtype=int))