from collections import Counter

import nltk
import pandas

from EDdA import data
from EDdA.cache import Cache

def frequenciesLoader(articles, n, domain):
    """Count n-gram occurrences in all texts of a domain.

    Selects the articles belonging to *domain*, splits each text on
    whitespace, and tallies every n-gram of consecutive tokens.

    :param articles: article collection accepted by ``data.domain``
    :param n: n-gram size (1 for unigrams, 2 for bigrams, ...)
    :param domain: domain name used to filter the articles
    :return: dict mapping n-gram tuples of words to their frequency
    """
    texts = data.domain(articles, domain).contentWithoutClass
    counts = Counter()
    for text in texts:
        # Counter.update consumes the ngrams generator directly — no need
        # to materialize it in a list first.
        counts.update(nltk.ngrams(text.split(), n))
    # Return a plain dict to keep the original return type.
    return dict(counts)

def frequenciesPath(inputHash, n):
    """Build a path policy for cached n-gram frequency files.

    :param inputHash: hash identifying the input corpus
    :param n: n-gram size
    :return: callable mapping a domain name to its TSV cache path
    """
    def path(domain):
        return f"frequencies/{inputHash}/{n}grams/{domain}.tsv"
    return path

def loadFrequencies(f):
    """Read a frequency table from a TSV file.

    The file must have an ``ngram`` column (comma-joined words) and a
    ``frequency`` column, as produced by :func:`saveFrequencies`.

    :param f: path or file-like object readable by pandas
    :return: dict mapping n-gram tuples of words to their frequency
    """
    table = pandas.read_csv(f, sep='\t', na_filter=False)
    # Recover the tuple keys from their comma-joined serialized form.
    keys = (tuple(serialized.split(',')) for serialized in table.ngram)
    return dict(zip(keys, table.frequency))

def saveFrequencies(freqs, f):
    """Write a frequency table to a TSV file.

    Each n-gram tuple is serialized as its words joined by commas,
    matching the format expected by :func:`loadFrequencies`.

    :param freqs: dict mapping n-gram tuples of words to frequencies
    :param f: path or file-like object writable by pandas
    """
    frame = pandas.DataFrame({
            'ngram': [','.join(words) for words in freqs],
            'frequency': list(freqs.values()),
        })
    frame.to_csv(f, sep='\t', index=False)

def frequencies(source, n):
    """Create a disk-backed cache of per-domain n-gram frequencies.

    :param source: object exposing ``articles`` and ``hash`` attributes
    :param n: n-gram size
    :return: a Cache keyed by domain name, computing frequencies on miss
    """
    def loader(domain):
        return frequenciesLoader(source.articles, n, domain)
    return Cache(
            loader,
            pathPolicy=frequenciesPath(source.hash, n),
            serializer=saveFrequencies,
            unserializer=loadFrequencies
        )

def topLoader(frequencyEvaluator, n, ranks):
    """Build a loader returning the most frequent n-grams of a domain.

    :param frequencyEvaluator: callable mapping a domain to a frequency dict
    :param n: n-gram size (unused here; kept for interface consistency
        with the other ``topNGrams`` helpers)
    :param ranks: number of top entries to keep
    :return: callable mapping a domain to a dict of its top n-grams
    """
    def top(domain):
        distribution = nltk.FreqDist(frequencyEvaluator(domain))
        return dict(distribution.most_common(ranks))
    return top

def topPath(inputHash, n, ranks):
    """Build a path policy for cached top-n-gram files.

    :param inputHash: hash identifying the input corpus
    :param n: n-gram size
    :param ranks: number of top entries kept
    :return: callable mapping a domain name to its TSV cache path
    """
    def path(domain):
        return f"topNGrams/{inputHash}/{n}grams/top{ranks}/{domain}.tsv"
    return path

def topNGrams(source, n, ranks):
    """Create a disk-backed cache of the top n-grams for each domain.

    Frequencies are themselves obtained through the :func:`frequencies`
    cache, so repeated calls reuse previously computed tables.

    :param source: object exposing ``articles`` and ``hash`` attributes
    :param n: n-gram size
    :param ranks: number of top entries to keep per domain
    :return: a Cache keyed by domain name
    """
    frequencyEvaluator = frequencies(source, n)
    return Cache(
            topLoader(frequencyEvaluator, n, ranks),
            pathPolicy=topPath(source.hash, n, ranks),
            serializer=saveFrequencies,
            unserializer=loadFrequencies
        )