Skip to content
Snippets Groups Projects
splitMulti.py 1.97 KiB
Newer Older
#!/usr/bin/env python3
from Corpus import Directory
from GEODE import toKey
from GEODE.Error import TwoAnnotations
from GEODE.util import initialise, parseRatio
import JSONL
from random import shuffle
from sys import argv, stdin
from prodigyAcceptedJSONLToTSV import acceptedToTSV
from prodigyMultiJSONLToDirectory import multiJSONLToDirectory

def getTexts(inputJSONL):
    """Group annotations by the text they refer to.

    Each annotation's ``'meta'`` is turned into a key with ``toKey``. For
    every key the result maps to a dict with:

    * ``'accept'``: the single annotation whose ``'answer'`` is ``'accept'``,
      or ``None`` if none was seen;
    * ``'reject'``: every annotation whose ``'answer'`` is not ``'accept'``,
      in input order.

    If a second accepted annotation shows up for a key, a ``TwoAnnotations``
    error is printed and every later annotation for that key is ignored.

    :param inputJSONL: iterable of annotation dicts (each with at least
        ``'meta'`` and ``'answer'``, and ``'label'`` on accepted ones).
    :returns: dict mapping key -> ``{'accept': ..., 'reject': [...]}``.
    """
    texts = {}
    errors = set()
    for annotation in inputJSONL:
        key = toKey(annotation['meta'])
        if key in errors:
            # A conflict was already reported for this text; skip the rest.
            continue
        # presumably creates texts[key] only when absent (the code below
        # relies on the entry persisting across annotations) — confirm.
        initialise(texts, key, {'accept': None, 'reject': []})
        if annotation['answer'] == 'accept':
            previous = texts[key]['accept']
            if previous is None:
                texts[key]['accept'] = annotation
            else:
                # Report both conflicting labels: the one kept earlier and
                # the one just read.
                print(TwoAnnotations(annotation['meta'],
                                     previous['label'],
                                     annotation['label']))
                # NOTE(review): the entry keeps the first accepted
                # annotation and still reaches the train/test split; confirm
                # that conflicting texts should not be dropped instead.
                errors.add(key)
        else:
            texts[key]['reject'].append(annotation)
    return texts

def getTest(texts, trainRatio):
    """Randomly pick the test set among texts that have an accepted annotation.

    A share of ``1 - trainRatio`` of those texts (rounded with ``round``)
    is selected after shuffling them with the module-level ``shuffle``.

    :param texts: mapping key -> ``{'accept': ..., 'reject': [...]}``.
    :param trainRatio: fraction of accepted texts to leave for training.
    :returns: dict mapping each selected key to its accepted annotation.
    """
    candidates = []
    for textKey, entry in texts.items():
        if entry['accept'] is not None:
            candidates.append(textKey)
    shuffle(candidates)
    testSize = round(len(candidates) * (1-trainRatio))
    chosen = candidates[:testSize]
    return dict((textKey, texts[textKey]['accept']) for textKey in chosen)

def allAnnotations(text):
    """Return every annotation of one grouped text, the accepted one first.

    When there is no accepted annotation, the ``'reject'`` list itself is
    returned (not a copy).
    """
    accepted = text['accept']
    if accepted is not None:
        return [accepted] + text['reject']
    return text['reject']

def getTrain(texts, test):
    """Collect the training annotations: everything not held out for testing.

    Texts are visited in sorted key order. For each text whose key is not in
    ``test``, all of its annotations (accepted first, then the others) are
    appended. This includes texts that have no accepted annotation at all.
    """
    train = []
    for textKey in sorted(texts):
        if textKey in test:
            continue
        train.extend(allAnnotations(texts[textKey]))
    return train

def splitMulti(jsonl, trainRatio, trainOutput, testOutput):
    """Split annotations into a training output and a test output.

    Annotations are grouped per text, and a random test set of accepted
    annotations is drawn. Everything else goes to ``multiJSONLToDirectory``
    at ``trainOutput``; the held-out accepted annotations go to
    ``acceptedToTSV`` at ``testOutput``.
    """
    texts = getTexts(jsonl)
    heldOut = getTest(texts, trainRatio)
    multiJSONLToDirectory(getTrain(texts, heldOut), trainOutput)
    acceptedToTSV(heldOut.values(), testOutput)

if __name__ == '__main__':
    # Expected usage: annotations as JSONL on stdin, then three arguments.
    # Fail with a readable message instead of an IndexError traceback.
    if len(argv) < 4:
        raise SystemExit(
            f"usage: {argv[0]} TRAIN_RATIO TRAIN_OUTPUT TEST_OUTPUT"
            " < annotations.jsonl")
    splitMulti(JSONL.load(stdin), parseRatio(argv[1]), argv[2], argv[3])