Commit 33f6e1b8 authored by Ludovic Moncla's avatar Ludovic Moncla
Update Classification_Zero-Shot-Learning.ipynb

%% Cell type:markdown id: tags:
# Zero Shot Topic Classification with Transformers
https://joeddav.github.io/blog/2020/05/29/ZSL.html
https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd
https://huggingface.co/spaces/joeddav/zero-shot-demo
%% Cell type:markdown id: tags:
## 1. Configuration
%% Cell type:markdown id: tags:
### 1.1 Setup colab environment
%% Cell type:code id: tags:
``` python
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')
```
%% Cell type:code id: tags:
``` python
from google.colab import drive
drive.mount('/content/drive')
```
%% Cell type:code id: tags:
``` python
output_path = "drive/MyDrive/Classification-EDdA/"
```
%% Cell type:markdown id: tags:
### 1.2 Import libraries
%% Cell type:code id: tags:
``` python
import pandas as pd
from transformers import BartForSequenceClassification, BartTokenizer
```
%% Cell type:markdown id: tags:
## 2. Load datasets
%% Cell type:markdown id: tags:
### 2.1 Download datasets
%% Cell type:code id: tags:
``` python
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv
```
%% Cell type:code id: tags:
``` python
dataset_path = 'EDdA_dataframe_withContent.tsv'
training_set_path = 'training_set.tsv'
test_set_path = 'test_set.tsv'
input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'
#input_path = ''
output_path = ''
```
%% Cell type:code id: tags:
``` python
df = pd.read_csv(input_path + test_set_path, sep="\t")
df.head()
```
%% Output
volume numero head normClass classEDdA \
0 11 2973 ORNIS Commerce Comm.
1 3 3525 COMPRENDRE Philosophie terme de Philosophie,
2 1 2560 ANCRE Marine Marine
3 16 4241 VAKEBARO Géographie moderne Géog. mod.
4 8 3281 INSPECTEUR Histoire ancienne Hist. anc.
author id_enccre domaine_enccre ensemble_domaine_enccre \
0 unsigned v11-1767-0 commerce Commerce
1 Diderot v3-1722-0 NaN NaN
2 d'Alembert & Diderot v1-1865-0 marine Marine
3 unsigned v16-2587-0 géographie Géographie
4 unsigned v8-2533-0 histoire Histoire
content \
0 ORNIS, s. m. toile des Indes, (Comm.) sortes d...
1 * COMPRENDRE, v. act. terme de Philosophie,\nc...
2 ANCRE, s. f. (Marine.) est un instrument de fe...
3 VAKEBARO, (Géog. mod.) vallée du royaume\nd'Es...
4 INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...
contentWithoutClass \
0 ORNIS, s. m. toile des Indes, () sortes de\nto...
1 * COMPRENDRE, v. act. \nc'est appercevoir la l...
2 ANCRE, s. f. (.) est un instrument de fer\nABC...
3 VAKEBARO, () vallée du royaume\nd'Espagne dans...
4 INSPECTEUR, s. m. inspector ; () celui \nà qui...
firstParagraph nb_word
0 ORNIS, s. m. toile des Indes, () sortes de\nto... 45
1 * COMPRENDRE, v. act. \nc'est appercevoir la l... 92
2 ANCRE, s. f. (.) est un instrument de fer\nABC... 3327
3 VAKEBARO, () vallée du royaume\nd'Espagne dans... 34
4 INSPECTEUR, s. m. inspector ; () celui \nà qui... 102
%% Cell type:code id: tags:
``` python
df.shape
```
%% Output
(15854, 13)
%% Cell type:code id: tags:
``` python
#column_text = 'contentWithoutClass'
column_text = 'content'
column_class = 'ensemble_domaine_enccre'
```
%% Cell type:code id: tags:
``` python
df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)
```
%% Cell type:code id: tags:
``` python
df.shape
```
%% Output
(13441, 13)
%% Cell type:code id: tags:
``` python
classes = df[column_class].unique().tolist()
classes
```
%% Output
['Commerce',
'Marine',
'Géographie',
'Histoire',
'Belles-lettres - Poésie',
'Economie domestique',
'Droit - Jurisprudence',
'Médecine - Chirurgie',
'Militaire (Art) - Guerre - Arme',
'Beaux-arts',
'Antiquité',
'Histoire naturelle',
'Grammaire',
'Philosophie',
'Arts et métiers',
'Pharmacie',
'Religion',
'Pêche',
'Anatomie',
'Architecture',
'Musique',
'Jeu',
'Caractères',
'Métiers',
'Physique - [Sciences physico-mathématiques]',
'Maréchage - Manège',
'Chimie',
'Blason',
'Chasse',
'Mathématiques',
'Médailles',
'Superstition',
'Agriculture - Economie rustique',
'Mesure',
'Monnaie',
'Minéralogie',
'Politique',
'Spectacle']
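%% Cell type:markdown id: tags:
Note that the class names above are used directly as hypotheses below. In the Yin et al. formulation, each candidate label is usually wrapped in a hypothesis template such as "This text is about X.", which often scores better than the bare label name. A minimal sketch (the template string is an illustrative assumption, not what this notebook uses):
%% Cell type:code id: tags:
``` python
# Sketch (assumption): wrap each label in a hypothesis template, as in Yin et al.
template = 'This text is about {}.'
labels = ['Commerce', 'Marine', 'Géographie']
hypotheses = [template.format(label) for label in labels]
print(hypotheses[0])
```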
%% Cell type:code id: tags:
``` python
df[column_text].tolist()[0]
```
%% Output
"ORNIS, s. m. toile des Indes, (Comm.) sortes de\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\ntoiles sont par bandes, moitié coton & moitié or &\nargent. Il y en a depuis quinze jusqu'à vingt aunes."
%% Cell type:markdown id: tags:
## 3. Classification
%% Cell type:markdown id: tags:
The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the "premise" and to turn each candidate label into a "hypothesis." If the NLI model predicts that the premise "entails" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers.
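%% Cell type:markdown id: tags:
The core arithmetic is simple: the MNLI head outputs three logits (contradiction, neutral, entailment); the neutral logit is discarded and a softmax over the remaining two gives the probability that the label applies. A minimal pure-Python sketch of that step, independent of the model (the helper name is ours, for illustration):
%% Cell type:code id: tags:
``` python
import math

def entailment_prob(contradiction_logit, entailment_logit):
    # drop the "neutral" logit and softmax over [contradiction, entailment];
    # the entailment share is read as the probability the label applies
    e_c = math.exp(contradiction_logit)
    e_e = math.exp(entailment_logit)
    return e_e / (e_c + e_e)

# with a higher entailment logit, the label wins
print(entailment_prob(1.0, 2.0))
```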
%% Cell type:code id: tags:
``` python
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
```
%% Output
%% Cell type:code id: tags:
``` python
'''
## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'
# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]
# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')
'''
```
%% Cell type:code id: tags:
``` python
# pose each article as an NLI premise and each candidate class as a hypothesis
premise = df[column_text].tolist()[0]
hypotheses = classes

def zero_shot_prediction(premise, hypotheses):
    # list to store the true probability of each hypothesis
    true_probs = []
    # loop through hypotheses
    for hypothesis in hypotheses:
        # run through model pre-trained on MNLI
        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = model(input_ids)[0]
        # we throw away "neutral" (dim 1) and take the probability of
        # "entailment" (2) as the probability of the label being true
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        true_prob = probs[:, 1].item() * 100
        # append true probability to list
        true_probs.append(true_prob)
    return true_probs

def get_highest_score(true_probs, hypotheses):
    # get index of hypothesis with highest score
    highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])
    # get hypothesis with highest score
    highest_hypothesis = hypotheses[highest_index]
    # get highest probability
    highest_prob = true_probs[highest_index]
    return (highest_hypothesis, highest_prob)

true_probs = zero_shot_prediction(premise, hypotheses)
highest_hypothesis, highest_prob = get_highest_score(true_probs, hypotheses)
# print the results
print(f'The hypothesis with the highest score is: "{highest_hypothesis}" with a probability of {highest_prob:0.2f}%')
```
%% Output
The hypothesis with the highest score is: "Commerce" with a probability of 70.05%
%% Cell type:code id: tags:
``` python
df[column_text].tolist()[0]
```
%% Cell type:code id: tags:
``` python
premise = df[column_text].tolist()[0]
true_probs = zero_shot_prediction(premise, classes)
highest_score = get_highest_score(true_probs, classes)
# print the results
print(f'The hypothesis with the highest score is: "{highest_score[0]}" with a probability of {highest_score[1]:0.2f}%')
```
%% Output
The hypothesis with the highest score is: "Commerce" with a probability of 70.05%
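%% Cell type:markdown id: tags:
Once a predicted class has been collected for each article (e.g. by looping `zero_shot_prediction` over the test set), the predictions can be scored against the gold `ensemble_domaine_enccre` labels. A minimal accuracy sketch with made-up toy predictions (the helper and the toy lists are illustrative assumptions):
%% Cell type:code id: tags:
``` python
def accuracy(gold, pred):
    # share of articles whose top-scoring hypothesis matches the gold label
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)

# toy illustration with made-up predictions
gold = ['Commerce', 'Marine', 'Histoire']
pred = ['Commerce', 'Marine', 'Géographie']
print(accuracy(gold, pred))
```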
%% Cell type:code id: tags:
``` python
```