#!/usr/bin/env python3
import numpy
import pandas
import pickle
import sklearn  # presumably imported so that loading the pickled label encoder fails early if scikit-learn is missing
from sys import argv
import torch
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
class Classifier:
    """
    A class wrapping all the different models and classes used throughout a
    classification task:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    containing the texts to classify
    """
    def __init__(self, root_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._init_tokenizer()
        self._init_model(root_path)
        self._init_pipe()
        self._init_encoder(f"{root_path}/label_encoder.pkl")
        self.log()

    def _init_model(self, path):
        bert = BertForSequenceClassification.from_pretrained(path)
        self.model = bert.to(self.device.type)

    def _init_tokenizer(self):
        model_name = 'bert-base-multilingual-cased'
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def _init_pipe(self):
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
            device=self.device)

    def _init_encoder(self, path):
        with open(path, 'rb') as pickled:
            self.encoder = pickle.load(pickled)

    def log(self):
        if self.device.type == 'cpu':
            print('No GPU available, using the CPU instead.')
        else:
            print('We will use the GPU:', torch.cuda.get_device_name(0))
    def __call__(self, text_generator):
        tokenizer_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
        predictions = []
        for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)):
            # sort the scores of all classes in decreasing order; labels are of
            # the form 'LABEL_<n>', so [6:] recovers the numeric class index
            byScoreDesc = sorted(output, key=lambda d: d['score'], reverse=True)
            predictions.append([int(byScoreDesc[0]['label'][6:]),
                                byScoreDesc[0]['score'],
                                int(byScoreDesc[1]['label'][6:])])
        predictions = numpy.array(predictions)
        return list(self.encoder.inverse_transform(predictions[:, 0].astype(int)))
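
# A minimal usage sketch (the paths and texts below are only placeholders, not
# part of this repository): once constructed from a directory containing a
# fine-tuned BERT model and its label_encoder.pkl, the Classifier is simply
# called on an iterable of texts and returns the decoded labels.
#
#     classify = Classifier("path/to/model_dir")
#     labels = classify(["Some text to classify", "Another text"])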
class Source:
    """
    A class to handle the normalised paths used in the project and to load the
    actual text input as a generator from records when it is needed
    """
    def __init__(self, root_path):
        """
        Positional arguments
        :param root_path: the path to a GÉODE-style folder containing the text
            version of the corpus on which to predict the classes
        """
        self.root_path = root_path

    def path_to(self, record):
        article_relative_path = "{work}/T{volume}/{article}".format(**record)
        prefix = f"{self.root_path}/{article_relative_path}"
        if 'paragraph' in record:
            return f"{prefix}/{record.paragraph}.txt"
        else:
            return f"{prefix}.txt"

    def load_text(self, record):
        with open(self.path_to(record), 'r') as file:
            return file.read()

    def iterate(self, records):
        for _, record in records.iterrows():
            yield self.load_text(record)
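
# For illustration only, with hypothetical field values: a record such as
# {'work': 'EDdA', 'volume': 1, 'article': 42, 'paragraph': 3} would resolve
# through path_to to "<root_path>/EDdA/T1/42/3.txt", while the same record
# without a 'paragraph' field would resolve to "<root_path>/EDdA/T1/42.txt".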
def label(classify, source, tsv_path, name='label'):
    """
    Make predictions on a set of documents

    Positional arguments
    :param classify: an instance of the Classifier class above
    :param source: an instance of the Source class above
    :param tsv_path: the path to a TSV file containing (at least) article or
        paragraph records (additional metadata will be ignored)

    Keyword arguments
    :param name: defaults to 'label'; the name of the column to be created,
        that is to say, the name of the category you are predicting with your
        model (if your model labels texts as "Red", "Green" or "Blue", you may
        want to use `name='color'`).

    :return: a pandas dataframe containing the records from the input TSV file
        plus an additional column
    """
    records = pandas.read_csv(tsv_path, sep='\t')
    records[name] = classify(source.iterate(records))
    return records
if __name__ == '__main__':
    classify = Classifier(argv[1])
    source = Source(argv[2])
    label(classify, source, argv[3]).to_csv(argv[4], sep='\t', index=False)
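
# Example invocation, with placeholder paths and a placeholder script name:
# the four positional arguments are the model directory (which must also
# contain label_encoder.pkl), the corpus root, the input TSV and the output
# TSV, in that order.
#
#     python3 classify.py models/bert-classifier corpus/Text articles.tsv articles_labelled.tsv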