Skip to content
Snippets Groups Projects
Commit cef1e9a8 authored by Ludovic Moncla's avatar Ludovic Moncla
Browse files

Create Classification_Zero-Shot-Learning.ipynb

parent d477ca43
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Zero Shot Topic Classification with Transformers
https://joeddav.github.io/blog/2020/05/29/ZSL.html
https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd
https://huggingface.co/spaces/joeddav/zero-shot-demo
%% Cell type:markdown id: tags:
## 1. Configuration
%% Cell type:markdown id: tags:
### 1.1 Setup colab environment
%% Cell type:code id: tags:
``` python
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
if ram_gb < 20:
print('Not using a high-RAM runtime')
else:
print('You are using a high-RAM runtime!')
```
%% Cell type:code id: tags:
``` python
from google.colab import drive
drive.mount('/content/drive')
```
%% Cell type:code id: tags:
``` python
path = "drive/MyDrive/Classification-EDdA/"
```
%% Cell type:markdown id: tags:
### 1.2 Import libraries
%% Cell type:code id: tags:
``` python
import pandas as pd
from transformers import BartForSequenceClassification, BartTokenizer
```
%% Cell type:markdown id: tags:
## 2. Load datasets
%% Cell type:markdown id: tags:
#### 2.1 Download datasets
%% Cell type:code id: tags:
``` python
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv
```
%% Cell type:code id: tags:
``` python
dataset_path = 'EDdA_dataframe_withContent.tsv'
training_set_path = 'training_set.tsv'
test_set_path = 'test_set.tsv'
```
%% Cell type:code id: tags:
``` python
df = pd.read_csv(test_set_path, sep="\t")
df.head()
```
%% Cell type:code id: tags:
``` python
column_text = 'contentWithoutClass'
column_class = 'ensemble_domaine_enccre'
```
%% Cell type:code id: tags:
``` python
df[column_text].tolist()[0]
```
%% Cell type:markdown id: tags:
## 3. Classification
%% Cell type:markdown id: tags:
The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the "premise" and to turn each candidate label into a "hypothesis." If the NLI model predicts that the premise "entails" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers.
%% Cell type:code id: tags:
``` python
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
```
%% Cell type:code id: tags:
``` python
# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'
# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]
# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment