Skip to content
Snippets Groups Projects
predict.py 1 KiB
Newer Older
Alice Brenon's avatar
Alice Brenon committed
from sys import argv

import pandas

from BERT import Classifier
from Corpus import corpus

def label(classify, source, name='label'):
    """
    Make predictions on a set of documents.

    Positional arguments
    :param classify: an instance of the Classifier class; it is called once with
    the sequence returned by `source.get_all('content')` and presumably returns
    one prediction per document
    :param source: an instance of the Corpus class

    Keyword arguments
    :param name: defaults to 'label' — the name of the column to be created, that is
    to say, the name of the category you are predicting with your model (if your
    model labels in "Red", "Green", or "Blue", you may want to use
    `name='color'`).

    :return: a pandas DataFrame built from the corpus' 'key' values, plus an
    additional column called `name` holding the predictions
    """
    # NOTE(review): assumes get_all('key') and get_all('content') yield the
    # documents in the same order, so each prediction lands on the row of the
    # document it was computed from — TODO confirm against Corpus.
    records = pandas.DataFrame(source.get_all('key'))
    records[name] = classify(source.get_all('content'))
    return records

if __name__ == '__main__':
    # Command line: MODEL CORPUS OUTPUT_TSV.
    # argv[1] is handed to Classifier and argv[2] to corpus(). The labelled
    # records are written to argv[3], tab-separated, without the DataFrame index.
    if len(argv) != 4:
        # Fail with a readable usage line instead of an IndexError below.
        raise SystemExit("Usage: %s MODEL CORPUS OUTPUT_TSV" % argv[0])
    classify = Classifier(argv[1])
    source = corpus(argv[2])
    label(classify, source).to_csv(argv[3], sep='\t', index=False)