diff --git a/EDdA/classification/classSimilarities.py b/EDdA/classification/classSimilarities.py
index bcc7c057db53ab2c300cf37fa345fd2a9a816e03..e36e3902361984fd061e46effdcb953924f40bb4 100644
--- a/EDdA/classification/classSimilarities.py
+++ b/EDdA/classification/classSimilarities.py
@@ -33,10 +33,11 @@ def confusionMatrix(vectorizer, metric, domains=data.domains):
             m[a][b] = metric(vectorizer(domains[a]), vectorizer(domains[b]))
     return m
 
-def toPNG(matrix, filePath, domains=data.domains):
+def toPNG(matrix, filePath, domains=list(map(data.shortDomain, data.domains)), **kwargs):
     plot.figure(figsize=(16,13))
+    if 'cmap' not in kwargs:
+        kwargs['cmap'] = 'Blues'
     ax = seaborn.heatmap(
-            matrix, xticklabels=domains, yticklabels=domains, cmap='Blues'
+            matrix, xticklabels=domains, yticklabels=domains, **kwargs
         )
     plot.savefig(filePath, dpi=300, bbox_inches='tight')
-
diff --git a/EDdA/data.py b/EDdA/data.py
index 2243c1c3b9a8da4de365023657eab832f9dd98d4..49ae5999ae1d358751a3cb4806ec500fa5736e9a 100644
--- a/EDdA/data.py
+++ b/EDdA/data.py
@@ -32,5 +32,12 @@ domains = [
 
 domainId = dict([(domains[k], k) for k in range(0, len(domains))])
 
+def shortDomain(name, maxSize=20):
+    if len(name) > maxSize:
+        components = name.split(' ')
+        return components[0] + ' […]'
+    else:
+        return name
+
 def domain(articles, name):
     return articles[articles.ensemble_domaine_enccre == name]