Skip to content
Snippets Groups Projects
predict.py 4.28 KiB
Newer Older
#!/usr/bin/env python3
import numpy
import pandas
import pickle
import sklearn
from sys import argv
import torch
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline

class Classifier:
    """
    A class wrapping all the different models and classes used throughout a
    classification task:

        - tokenizer
        - classifier
        - pipeline
        - label encoder

    Once created, it behaves as a function which you apply to a generator
    containing the texts to classify; it returns the list of predicted labels
    as decoded by the label encoder.
    """

    # Prefix stripped from the label names found in the pipeline output in
    # order to recover the numeric class id (e.g. 'LABEL_12' -> 12).
    _LABEL_PREFIX = 'LABEL_'

    def __init__(self, root_path):
        """
        Positional arguments
        :param root_path: the path to a folder from which the model can be
        loaded with `from_pretrained` and which also contains the pickled
        label encoder in a file named `label_encoder.pkl`
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._init_tokenizer()
        self._init_model(root_path)
        self._init_pipe()
        self._init_encoder(f"{root_path}/label_encoder.pkl")
        self.log()

    def _init_model(self, path):
        """Load the classification model from `path` and move it to the device."""
        bert = BertForSequenceClassification.from_pretrained(path)
        self.model = bert.to(self.device.type)

    def _init_tokenizer(self):
        """Load the tokenizer of the (hard-coded) multilingual cased BERT."""
        model_name = 'bert-base-multilingual-cased'
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def _init_pipe(self):
        """Assemble model and tokenizer into a text classification pipeline."""
        self.pipe = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            # every class score is requested; the best one is picked in
            # _top_class_id
            return_all_scores=True,
            device=self.device)

    def _init_encoder(self, path):
        """Load the pickled label encoder stored at `path`."""
        # NOTE(security): unpickling executes arbitrary code, only point this
        # at label encoder files coming from a trusted source.
        with open(path, 'rb') as pickled:
            self.encoder = pickle.load(pickled)

    def log(self):
        """Print which device (CPU or GPU) is going to be used."""
        if self.device.type == 'cpu':
            print('No GPU available, using the CPU instead.')
        else:
            print('We will use the GPU:', torch.cuda.get_device_name(0))

    @classmethod
    def _top_class_id(cls, scores):
        """
        Return the numeric id of the best-scoring class of one pipeline output

        Positional arguments
        :param scores: a non-empty list of dicts with a 'label' key (a string
        starting with the 'LABEL_' prefix followed by the class id) and a
        'score' key

        :return: the class id of the entry with the highest score as an int
        (the first one in case of a tie)
        """
        best = max(scores, key=lambda d: d['score'])
        return int(best['label'][len(cls._LABEL_PREFIX):])

    def __call__(self, text_generator):
        """
        Predict the label of each input text

        Positional arguments
        :param text_generator: an iterable of the texts to classify

        :return: the list of predicted labels (one per text, in the same
        order), an empty list if there was no text to classify
        """
        tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512}
        outputs = self.pipe(text_generator, **tokenizer_kwargs)
        # Only the best class is kept: looking up a runner-up would fail on a
        # model with a single label and was never part of the result anyway.
        class_ids = [self._top_class_id(scores) for scores in tqdm(outputs)]
        if not class_ids:
            # nothing to decode
            return []
        return list(self.encoder.inverse_transform(numpy.array(class_ids, dtype=int)))

class Source:
    """
    A class to handle the normalised path used in the project and loading the
    actual text input as a generator from records when they are needed
    """
    def __init__(self, root_path):
        """
        Positional arguments
        :param root_path: the path to a GÉODE-style folder containing the text
        version of the corpus on which to predict the classes
        """
        self.root_path = root_path

    def path_to(self, record):
        """
        Compute the path of the text file corresponding to a record

        Positional arguments
        :param record: a mapping (a row of a dataframe or a plain dict) with
        the keys 'work', 'volume' and 'article', plus 'paragraph' when the
        record describes a paragraph instead of a whole article

        :return: `<root>/<work>/T<volume>/<article>.txt` for an article,
        `<root>/<work>/T<volume>/<article>/<paragraph>.txt` for a paragraph
        """
        article_relative_path = "{work}/T{volume}/{article}".format(**record)
        prefix = f"{self.root_path}/{article_relative_path}"
        if 'paragraph' in record:
            # key access (instead of attribute access) works with any
            # mapping, consistently with the `format(**record)` above
            return f"{prefix}/{record['paragraph']}.txt"
        return f"{prefix}.txt"

    def load_text(self, record):
        """
        Read the whole text of the article or paragraph described by a record

        :param record: see `path_to`
        :return: the content of the corresponding file as a string
        """
        # The encoding is explicit so the result does not depend on the locale
        # of the platform; assumes the corpus files are UTF-8 encoded.
        with open(self.path_to(record), 'r', encoding='utf-8') as file:
            return file.read()

    def iterate(self, records):
        """
        Lazily load the texts of all the records of a dataframe

        :param records: a dataframe with one record per row
        :return: a generator yielding the text of each row, in order
        """
        for _, record in records.iterrows():
            yield self.load_text(record)

def label(classify, source, tsv_path, name='label'):
    """
    Predict a class for every document listed in a TSV file

    Positional arguments
    :param classify: a callable turning an iterable of texts into the list of
    their predicted classes (typically an instance of the Classifier class
    above)
    :param source: an object whose `iterate` method yields the text of each
    record of a dataframe (typically an instance of the Source class above)
    :param tsv_path: the path to a tab-separated file with one article or
    paragraph record per line (any extra metadata column is kept untouched)

    Keyword arguments
    :param name: defaults to 'label' — the name of the column receiving the
    predictions, i.e. the name of the category predicted by the model (for a
    model answering "Red", "Green" or "Blue", `name='color'` would be a
    sensible choice)

    :return: a pandas dataframe with all the records read from the TSV file
    and one more column holding the predictions
    """
    table = pandas.read_csv(tsv_path, sep='\t')
    # the texts are loaded lazily, as the classifier consumes them
    texts = source.iterate(table)
    predicted = classify(texts)
    table[name] = predicted
    return table

if __name__ == '__main__':
    # Command line: MODEL_DIR CORPUS_DIR INPUT_TSV OUTPUT_TSV
    # The arguments are checked before anything else so that a missing one is
    # reported with a usage message right away instead of an IndexError
    # raised after the model has been loaded. Extra arguments are ignored.
    if len(argv) < 5:
        raise SystemExit(
            f"Usage: {argv[0]} MODEL_DIR CORPUS_DIR INPUT_TSV OUTPUT_TSV")
    model_dir, corpus_dir, input_tsv, output_tsv = argv[1:5]
    classify = Classifier(model_dir)
    source = Source(corpus_dir)
    # the predictions are written as TSV too, without the dataframe index
    label(classify, source, input_tsv).to_csv(output_tsv, sep='\t', index=False)