diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..e95b3052ebd5de93489df96490b9df10fe641f74 --- /dev/null +++ b/scripts/ML/predict.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +import numpy +import pandas +import pickle +import sklearn +from sys import argv +import torch +from tqdm import tqdm +from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline + +class Classifier: + """ + A class wrapping all the different models and classes used throughout a + classification task: + + - tokenizer + - classifier + - pipeline + - label encoder + + Once created, it behaves as a function which you apply to a generator + containing the texts to classify + """ + def __init__(self, root_path): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._init_tokenizer() + self._init_model(root_path) + self._init_pipe() + self._init_encoder(f"{root_path}/label_encoder.pkl") + self.log() + + def _init_model(self, path): + bert = BertForSequenceClassification.from_pretrained(path) + self.model = bert.to(self.device.type) + + def _init_tokenizer(self): + model_name = 'bert-base-multilingual-cased' + self.tokenizer = BertTokenizer.from_pretrained(model_name) + + def _init_pipe(self): + self.pipe = TextClassificationPipeline( + model=self.model, + tokenizer=self.tokenizer, + return_all_scores=True, + device=self.device) + + def _init_encoder(self, path): + with open(path, 'rb') as pickled: + self.encoder = pickle.load(pickled) + + def log(self): + if self.device.type == 'cpu': + print('No GPU available, using the CPU instead.') + else: + print('We will use the GPU:', torch.cuda.get_device_name(0)) + + def __call__(self, text_generator): + tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512} + predictions = [] + for output in tqdm(self.pipe(text_generator, **tokenizer_kwargs)): + byScoreDesc = sorted(output, key=lambda d: d['score'], reverse=True) + predictions.append([int(byScoreDesc[0]['label'][6:]), + byScoreDesc[0]['score'], + int(byScoreDesc[1]['label'][6:])]) + predictions = numpy.array(predictions) + return list(self.encoder.inverse_transform(predictions[:,0].astype(int))) + +class Source: + """ + A class to handle the normalised path used in the project and loading the + actual text input as a generator from records when they are needed + """ + def __init__(self, root_path): + """ + Positional arguments + :param root_path: the path to a GÉODE-style folder containing the text + version of the corpus on which to predict the classes + """ + self.root_path = root_path + + def path_to(self, record): + article_relative_path = "{work}/T{volume}/{article}".format(**record) + prefix = f"{self.root_path}/{article_relative_path}" + if 'paragraph' in record: + return f"{prefix}/{record.paragraph}.txt" + else: + return f"{prefix}.txt" + + def load_text(self, record): + with open(self.path_to(record), 'r') as file: + return file.read() + + def iterate(self, records): + for _, record in records.iterrows(): + yield self.load_text(record) + +def label(classify, source, tsv_path, name='label'): + """ + Make predictions on a set of document + + Positional arguments + :param classify: an instance of the Classifier class above + :param source: an instance of the Source class above + :param tsv_path: the path to a TSV file containing (at least) article or + paragraph records (additional metadata will be ignored) + + Keyword arguments + :param name: defaults to 'label' — the name of the column to be created, that is + to say, the name of the category you are predicting with your model (if your + model labels in "Red", "Green", or "Blue", you may want to use + `name='color'`). + + :return: a panda dataframe containing the records from the input TSV file plus + an additional column + """ + records = pandas.read_csv(tsv_path, sep='\t') + records[name] = classify(source.iterate(records)) + return records + +if __name__ == '__main__': + classify = Classifier(argv[1]) + source = Source(argv[2]) + label(classify, source, argv[3]).to_csv(argv[4], sep='\t')