from BERT.Base import BERT

import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline


class Classifier(BERT):
    """
    A class wrapping all the different models and classes used throughout a
    BERT-based classification task:
      - tokenizer
      - classifier
      - pipeline
      - label encoder
    Once created, it behaves as a function that you apply to a generator
    yielding the texts to classify.
    """

    def __init__(self, root_path):
        BERT.__init__(self, root_path)
        self._init_pipe()

    def _init_pipe(self):
        # Build a Hugging Face text-classification pipeline around the model
        # and tokenizer loaded by the BERT base class.
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
            device=self.device)

    def __call__(self, text_generator):
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        predictions = []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # Sort the per-label scores in descending order; labels have the
            # form "LABEL_<n>", so the slice [6:] extracts the numeric id.
            byScoreDesc = sorted(output, key=lambda d: d['score'], reverse=True)
            predictions.append([int(byScoreDesc[0]['label'][6:]),
                                byScoreDesc[0]['score'],
                                int(byScoreDesc[1]['label'][6:])])
        # Decode the top label ids back into their original class names.
        return self.encoder.inverse_transform(
            numpy.array(predictions)[:, 0].astype(int))
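
To make the calling convention concrete, here is a minimal usage sketch. It assumes the BERT base class loads the model, tokenizer, device, and fitted label encoder from the given root path; the "./model" checkpoint directory and the example texts are purely illustrative.

# Minimal usage sketch; "./model" is a hypothetical checkpoint directory.
texts = ["This movie was great!", "A thoroughly boring plot."]

classifier = Classifier("./model")
labels = classifier(texts)  # one decoded label per input text
print(labels)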