Skip to content
Snippets Groups Projects
Commit 85986ba9 authored by Alice Brenon's avatar Alice Brenon
Browse files

Add function to EDdA.data module to keep only the first word of multi-word...

Add function to EDdA.data module to keep only the first word of multi-word domain labels longer than 20
parent c704a10e
No related branches found
No related tags found
No related merge requests found
...@@ -33,10 +33,11 @@ def confusionMatrix(vectorizer, metric, domains=data.domains): ...@@ -33,10 +33,11 @@ def confusionMatrix(vectorizer, metric, domains=data.domains):
m[a][b] = metric(vectorizer(domains[a]), vectorizer(domains[b])) m[a][b] = metric(vectorizer(domains[a]), vectorizer(domains[b]))
return m return m
def toPNG(matrix, filePath, domains=data.domains): def toPNG(matrix, filePath, domains=list(map(data.shortDomain, data.domains)), **kwargs):
plot.figure(figsize=(16,13)) plot.figure(figsize=(16,13))
if 'cmap' not in kwargs:
kwargs['cmap'] = 'Blues'
ax = seaborn.heatmap( ax = seaborn.heatmap(
matrix, xticklabels=domains, yticklabels=domains, cmap='Blues' matrix, xticklabels=domains, yticklabels=domains, **kwargs
) )
plot.savefig(filePath, dpi=300, bbox_inches='tight') plot.savefig(filePath, dpi=300, bbox_inches='tight')
...@@ -32,5 +32,12 @@ domains = [ ...@@ -32,5 +32,12 @@ domains = [
domainId = dict([(domains[k], k) for k in range(0, len(domains))]) domainId = dict([(domains[k], k) for k in range(0, len(domains))])
def shortDomain(name, maxSize=20):
if len(name) > maxSize:
components = name.split(' ')
return components[0] + ' […]'
else:
return name
def domain(articles, name): def domain(articles, name):
return articles[articles.ensemble_domaine_enccre == name] return articles[articles.ensemble_domaine_enccre == name]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment