Skip to content
Snippets Groups Projects
Commit 45694284 authored by Ludovic Moncla's avatar Ludovic Moncla
Browse files

Update Classification_Zero-Shot-Learning.ipynb

parent bf0eff0f
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Zero Shot Topic Classification with Transformers
https://joeddav.github.io/blog/2020/05/29/ZSL.html
https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd
https://huggingface.co/spaces/joeddav/zero-shot-demo
%% Cell type:markdown id: tags:
## 1. Configuration
%% Cell type:markdown id: tags:
### 1.1 Setup colab environment
%% Cell type:code id: tags:
``` python
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
if ram_gb < 20:
print('Not using a high-RAM runtime')
else:
print('You are using a high-RAM runtime!')
from google.colab import drive
drive.mount('/content/drive')
output_path = "drive/MyDrive/Classification-EDdA/"
```
%% Cell type:markdown id: tags:
### 1.2 Import libraries
%% Cell type:code id: tags:
``` python
import pandas as pd
from tqdm import tqdm
from transformers import BartForSequenceClassification, BartTokenizer
```
%% Cell type:markdown id: tags:
## 2. Load datasets
%% Cell type:markdown id: tags:
#### 2.1 Download datasets
%% Cell type:code id: tags:
``` python
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv
```
%% Cell type:code id: tags:
``` python
dataset_path = 'EDdA_dataframe_withContent.tsv'
training_set_path = 'training_set.tsv'
test_set_path = 'test_set.tsv'
input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'
#input_path = ''
output_path = ''
```
%% Cell type:code id: tags:
``` python
df = pd.read_csv(input_path + test_set_path, sep="\t")
df.head()
```
%% Output
volume numero head normClass classEDdA \
0 11 2973 ORNIS Commerce Comm.
1 3 3525 COMPRENDRE Philosophie terme de Philosophie,
2 1 2560 ANCRE Marine Marine
3 16 4241 VAKEBARO Géographie moderne Géog. mod.
4 8 3281 INSPECTEUR Histoire ancienne Hist. anc.
author id_enccre domaine_enccre ensemble_domaine_enccre \
0 unsigned v11-1767-0 commerce Commerce
1 Diderot v3-1722-0 NaN NaN
2 d'Alembert & Diderot v1-1865-0 marine Marine
3 unsigned v16-2587-0 géographie Géographie
4 unsigned v8-2533-0 histoire Histoire
content \
0 ORNIS, s. m. toile des Indes, (Comm.) sortes d...
1 * COMPRENDRE, v. act. terme de Philosophie,\nc...
2 ANCRE, s. f. (Marine.) est un instrument de fe...
3 VAKEBARO, (Géog. mod.) vallée du royaume\nd'Es...
4 INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...
contentWithoutClass \
0 ORNIS, s. m. toile des Indes, () sortes de\nto...
1 * COMPRENDRE, v. act. \nc'est appercevoir la l...
2 ANCRE, s. f. (.) est un instrument de fer\nABC...
3 VAKEBARO, () vallée du royaume\nd'Espagne dans...
4 INSPECTEUR, s. m. inspector ; () celui \nà qui...
firstParagraph nb_word
0 ORNIS, s. m. toile des Indes, () sortes de\nto... 45
1 * COMPRENDRE, v. act. \nc'est appercevoir la l... 92
2 ANCRE, s. f. (.) est un instrument de fer\nABC... 3327
3 VAKEBARO, () vallée du royaume\nd'Espagne dans... 34
4 INSPECTEUR, s. m. inspector ; () celui \nà qui... 102
%% Cell type:code id: tags:
``` python
df.shape
```
%% Output
(15854, 13)
%% Cell type:code id: tags:
``` python
#column_text = 'contentWithoutClass'
column_text = 'content'
column_class = 'ensemble_domaine_enccre'
```
%% Cell type:code id: tags:
``` python
df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)
```
%% Cell type:code id: tags:
``` python
df.shape
```
%% Output
(13441, 13)
%% Cell type:code id: tags:
``` python
classes = df[column_class].unique().tolist()
classes
```
%% Output
['Commerce',
'Marine',
'Géographie',
'Histoire',
'Belles-lettres - Poésie',
'Economie domestique',
'Droit - Jurisprudence',
'Médecine - Chirurgie',
'Militaire (Art) - Guerre - Arme',
'Beaux-arts',
'Antiquité',
'Histoire naturelle',
'Grammaire',
'Philosophie',
'Arts et métiers',
'Pharmacie',
'Religion',
'Pêche',
'Anatomie',
'Architecture',
'Musique',
'Jeu',
'Caractères',
'Métiers',
'Physique - [Sciences physico-mathématiques]',
'Maréchage - Manège',
'Chimie',
'Blason',
'Chasse',
'Mathématiques',
'Médailles',
'Superstition',
'Agriculture - Economie rustique',
'Mesure',
'Monnaie',
'Minéralogie',
'Politique',
'Spectacle']
%% Cell type:markdown id: tags:
## 3. Classification
%% Cell type:markdown id: tags:
The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the "premise" and to turn each candidate label into a "hypothesis." If the NLI model predicts that the premise "entails" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers.
%% Cell type:code id: tags:
``` python
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
```
%% Cell type:code id: tags:
``` python
'''
## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'
# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]
# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')
'''
```
%% Cell type:code id: tags:
``` python
def zero_shot_prediction(premise, hypotheses):
# list to store the true probability of each hypothesis
true_probs = []
# loop through hypotheses
for hypothesis in hypotheses:
# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]
# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
# append true probability to list
true_probs.append(true_prob)
return true_probs
def get_highest_score(true_probs, hypotheses):
# get index of hypothesis with highest score
highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])
# get hypothesis with highest score
highest_hypothesis = hypotheses[highest_index]
# get highest probability
highest_prob = true_probs[highest_index]
return (highest_hypothesis, highest_prob)
def get_sorted_scores(true_probs, hypotheses):
# sort hypotheses based on their scores
sorted_hypotheses = [hypothesis for _, hypothesis in sorted(zip(true_probs, hypotheses), reverse=True)]
# sort scores
sorted_scores = sorted(true_probs, reverse=True)
return list(zip(sorted_hypotheses, sorted_scores))
```
%% Cell type:code id: tags:
``` python
# test
premise = df[column_text].tolist()[0]
true_probs = zero_shot_prediction(premise, classes)
highest_score = get_highest_score(true_probs, classes)
# print the results
print(f'The hypothesis with the highest score is: "{highest_score[0]}" with a probability of {highest_score[1]:0.2f}%')
probs = get_sorted_scores(true_probs, classes)
probs
```
%% Output
The hypothesis with the highest score is: "Commerce" with a probability of 70.05%
[('Commerce', 70.05096077919006),
('Anatomie', 68.73840689659119),
('Politique', 60.71174740791321),
('Géographie', 59.156250953674316),
('Architecture', 58.74174237251282),
('Histoire', 57.459235191345215),
('Agriculture - Economie rustique', 53.53081226348877),
('Histoire naturelle', 48.459288477897644),
('Antiquité', 46.68458700180054),
('Beaux-arts', 42.856183648109436),
('Mesure', 41.31035804748535),
('Jeu', 41.22118949890137),
('Droit - Jurisprudence', 41.1332905292511),
('Minéralogie', 38.137245178222656),
('Spectacle', 37.80339956283569),
('Pêche', 37.214648723602295),
('Superstition', 36.727988719940186),
('Arts et métiers', 36.511969566345215),
('Métiers', 36.5054726600647),
('Monnaie', 35.89862287044525),
('Musique', 32.74966776371002),
('Mathématiques', 32.70111680030823),
('Chasse', 29.35197949409485),
('Economie domestique', 28.346234560012817),
('Philosophie', 27.653270959854126),
('Chimie', 25.783824920654297),
('Physique - [Sciences physico-mathématiques]', 25.4037082195282),
('Médailles', 24.58679974079132),
('Grammaire', 22.36253321170807),
('Caractères', 20.14845609664917),
('Pharmacie', 19.720394909381866),
('Militaire (Art) - Guerre - Arme', 19.682711362838745),
('Médecine - Chirurgie', 18.615825474262238),
('Marine', 18.208028376102448),
('Belles-lettres - Poésie', 13.306896388530731),
('Blason', 10.476677119731903),
('Religion', 9.702161699533463),
('Maréchage - Manège', 4.211411997675896)]
%% Cell type:code id: tags:
``` python
y_true = df[column_class].tolist()
```
%% Cell type:code id: tags:
``` python
def get_tsv_content(y_true, prob_labels):
c = ''
for i, row in enumerate(prob_labels):
c += y_true[i] + '\t'
for t in row:
c += t[0] + '\t' + str(t[1])+'\t'
c += '\n'
return c
```
%% Cell type:code id: tags:
``` python
texts = df[column_text].tolist()
batch_size = 20
for i in tqdm(range(0, len(texts), batch_size)):
batch = texts[i:i+batch_size]
batch_y_true = y_true[i:i+batch_size]
prob_labels = []
for content in batch:
true_probs = zero_shot_prediction(content[:512], classes)
#pred_labels.append(get_highest_score(true_probs, classes)[0])
prob_labels.append(get_sorted_scores(true_probs, classes))
with open('zero-shot-classification.tsv', 'a') as f:
f.write(get_tsv_content(batch_y_true, prob_labels))
#print(prob_labels)
```
%% Output
0%| | 0/673 [00:00<?, ?it/s]
100%|██████████| 673/673 [54:36:49<00:00, 292.14s/it]
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment