from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline

class Classifier(BERT):
    """
    A class wrapping all the different models and classes used throughout a
    classification task and based on BERT:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    containing the texts to classify
    """
    def __init__(self, root_path):
        super().__init__(root_path)
        self._init_pipe()

    def _init_pipe(self):
        # Build the Hugging Face pipeline around the model and tokenizer loaded
        # by the BERT base class. `return_all_scores=True` makes the pipeline
        # return the score of every class (recent transformers releases
        # deprecate it in favour of `top_k=None`).
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
            device=self.device)

    def __call__(self, text_generator):
        """
        Classify the texts yielded by `text_generator` and return the decoded
        label predicted for each of them.
        """
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        predictions = []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # Each output is a list of {'label': 'LABEL_<i>', 'score': ...}
            # dicts, one per class; sort them by descending score.
            by_score_desc = sorted(output, key=lambda d: d['score'], reverse=True)
            # Keep the best class id, its score, and the runner-up class id
            # (the 'LABEL_' prefix is 6 characters long, hence the slice).
            predictions.append([int(by_score_desc[0]['label'][6:]),
                                by_score_desc[0]['score'],
                                int(by_score_desc[1]['label'][6:])])
        # Only the top class ids are decoded back to their original labels.
        return self.encoder.inverse_transform(
                numpy.array(predictions)[:, 0].astype(int))
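

# Minimal usage sketch (illustrative only): the model directory and the sample
# texts below are placeholders, and the sketch assumes `root_path` points to a
# layout understood by the BERT base class, which loads the model, tokenizer,
# label encoder and device.
if __name__ == "__main__":
    classifier = Classifier("models/example-classifier")  # hypothetical path
    texts = ["first document to classify", "second document to classify"]
    for text, label in zip(texts, classifier(texts)):
        print(f"{label}\t{text}")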