from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline

class Classifier(BERT):
    """
    A class wrapping all the different models and classes used throughout a
    classification task and based on BERT:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    containing the texts to classify.

    The `model`, `tokenizer`, `device` and `encoder` attributes read by this
    class are not assigned here; they are expected to be set up by
    `BERT.__init__`.
    """

    def __init__(self, root_path):
        """
        Initialise the BERT base class, then build the classification pipeline.

        :param root_path: forwarded unchanged to `BERT.__init__`
            (presumably where the saved model lives -- confirm in BERT.Base).
        """
        super().__init__(root_path)
        self._init_pipe()

    def _init_pipe(self):
        """
        Create `self.pipe`, a text-classification pipeline built from the
        model, tokenizer and device of this instance.

        `top_k=1` asks the pipeline to keep only the best-scoring label for
        each input text.
        """
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=1,
            device=self.device)

    def __call__(self, text_generator):
        """
        Classify every text produced by `text_generator`.

        :param text_generator: the texts to classify, handed as-is to the
            pipeline.
        :return: a pair `(labels, scores)` where `labels` is the result of
            `self.encoder.inverse_transform` applied to the predicted class
            indices and `scores` is a numpy array holding the score of each
            prediction, in input order.
        :raises ValueError: if a predicted label does not end with an integer
            once its prefix has been removed.
        """
        # Forwarded to the tokenizer: pad batches and cut texts to 512 tokens.
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        # Predicted labels are expected to look like 'LABEL_<index>'; only the
        # prefix length matters here, the prefix itself is not checked.
        prefix_length = len('LABEL_')
        labels, scores = [], []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # `output` is the list of candidate predictions for one text; a
            # single pass with max() is enough to find the best one (on ties
            # the first candidate wins, exactly as a stable descending sort).
            best = max(output, key=lambda candidate: candidate['score'])
            labels.append(int(best['label'][prefix_length:]))
            scores.append(best['score'])
        return self.encoder.inverse_transform(labels), numpy.array(scores)