# Zero Shot Topic Classification with Transformers

https://joeddav.github.io/blog/2020/05/29/ZSL.html

https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd

https://huggingface.co/spaces/joeddav/zero-shot-demo

## 1. Configuration

### 1.1 Setup colab environment

In [None]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

from google.colab import drive
drive.mount('/content/drive')

output_path = "drive/MyDrive/Classification-EDdA/"

### 1.2 Import libraries

In [1]:
import pandas as pd
from tqdm import tqdm
from transformers import BartForSequenceClassification, BartTokenizer

## 2. Load datasets

### 2.1 Download datasets

In [None]:
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv
!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv

In [2]:
dataset_path = 'EDdA_dataframe_withContent.tsv'
training_set_path = 'training_set.tsv'
test_set_path = 'test_set.tsv'

input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'
#input_path = ''
output_path = ''

In [3]:
df = pd.read_csv(input_path + test_set_path, sep="\t")
df.head()

Unnamed: 0,volume,numero,head,normClass,classEDdA,author,id_enccre,domaine_enccre,ensemble_domaine_enccre,content,contentWithoutClass,firstParagraph,nb_word
0,11,2973,ORNIS,Commerce,Comm.,unsigned,v11-1767-0,commerce,Commerce,"ORNIS, s. m. toile des Indes, (Comm.) sortes d...","ORNIS, s. m. toile des Indes, () sortes de\nto...","ORNIS, s. m. toile des Indes, () sortes de\nto...",45
1,3,3525,COMPRENDRE,Philosophie,"terme de Philosophie,",Diderot,v3-1722-0,,,"* COMPRENDRE, v. act. terme de Philosophie,\nc...","* COMPRENDRE, v. act. \nc'est appercevoir la l...","* COMPRENDRE, v. act. \nc'est appercevoir la l...",92
2,1,2560,ANCRE,Marine,Marine,d'Alembert & Diderot,v1-1865-0,marine,Marine,"ANCRE, s. f. (Marine.) est un instrument de fe...","ANCRE, s. f. (.) est un instrument de fer\nABC...","ANCRE, s. f. (.) est un instrument de fer\nABC...",3327
3,16,4241,VAKEBARO,Géographie moderne,Géog. mod.,unsigned,v16-2587-0,géographie,Géographie,"VAKEBARO, (Géog. mod.) vallée du royaume\nd'Es...","VAKEBARO, () vallée du royaume\nd'Espagne dans...","VAKEBARO, () vallée du royaume\nd'Espagne dans...",34
4,8,3281,INSPECTEUR,Histoire ancienne,Hist. anc.,unsigned,v8-2533-0,histoire,Histoire,"INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...","INSPECTEUR, s. m. inspector ; () celui \nà qui...","INSPECTEUR, s. m. inspector ; () celui \nà qui...",102


In [4]:
df.shape

(15854, 13)

In [5]:
#column_text = 'contentWithoutClass'
column_text = 'content'
column_class = 'ensemble_domaine_enccre'

In [6]:
df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)

In [7]:
df.shape

(13441, 13)

In [8]:
classes = df[column_class].unique().tolist()
classes

['Commerce',
 'Marine',
 'Géographie',
 'Histoire',
 'Belles-lettres - Poésie',
 'Economie domestique',
 'Droit - Jurisprudence',
 'Médecine - Chirurgie',
 'Militaire (Art) - Guerre - Arme',
 'Beaux-arts',
 'Antiquité',
 'Histoire naturelle',
 'Grammaire',
 'Philosophie',
 'Arts et métiers',
 'Pharmacie',
 'Religion',
 'Pêche',
 'Anatomie',
 'Architecture',
 'Musique',
 'Jeu',
 'Caractères',
 'Métiers',
 'Physique - [Sciences physico-mathématiques]',
 'Maréchage - Manège',
 'Chimie',
 'Blason',
 'Chasse',
 'Mathématiques',
 'Médailles',
 'Superstition',
 'Agriculture - Economie rustique',
 'Mesure',
 'Monnaie',
 'Minéralogie',
 'Politique',
 'Spectacle']

## 3. Classification

The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the "premise" and to turn each candidate label into a "hypothesis." If the NLI model predicts that the premise "entails" the hypothesis, we take the label to be true. The code snippet below demonstrates how easily this can be done with 🤗 Transformers.
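The blog post's example wraps the label in a full sentence ("This text is about politics."), whereas the prediction loop in this notebook passes each class name as the hypothesis directly. A minimal sketch of the wrapping step is given here; the helper name `build_hypotheses` and its default template are illustrative assumptions, not part of the notebook or of the Transformers API.

```python
def build_hypotheses(labels, template="This text is about {}."):
    """Turn each candidate label into an NLI hypothesis sentence.

    `build_hypotheses` and `template` are hypothetical names used for
    illustration only.
    """
    return [template.format(label) for label in labels]


labels = ["Commerce", "Marine", "Philosophie"]
for hypothesis in build_hypotheses(labels):
    print(hypothesis)
```

Whether a template helps on these French class names with an English MNLI model is an open question that would need to be checked empirically.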

In [9]:
# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

In [None]:
''' 
## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html

# load model pretrained on MNLI
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'

# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true 
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')
'''
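The scoring step in the example above reduces to a two-way softmax over the contradiction and entailment logits. The arithmetic can be sketched without loading the model; the function name `entailment_probability` and the logit values are made up for illustration, and the label order (0 = contradiction, 1 = neutral, 2 = entailment) is taken from the comments in the example.

```python
import math


def entailment_probability(logits):
    """Return P(entailment) from MNLI logits, ignoring the neutral class.

    Assumes the label order used in the example above:
    index 0 = contradiction, 1 = neutral, 2 = entailment.
    """
    contradiction, _neutral, entailment = logits
    # two-way softmax over (contradiction, entailment), computed stably
    m = max(contradiction, entailment)
    exp_c = math.exp(contradiction - m)
    exp_e = math.exp(entailment - m)
    return exp_e / (exp_c + exp_e)


# hypothetical logits, chosen only to illustrate the arithmetic
print(f"{entailment_probability([-1.0, 3.0, 1.0]) * 100:0.2f}%")
```

Note that a large neutral logit has no effect on the result: only the gap between the entailment and contradiction logits matters.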

In [10]:
def zero_shot_prediction(premise, hypotheses):
    # list to store the true probability of each hypothesis
    true_probs = []

    # loop through hypotheses
    for hypothesis in hypotheses:

        # run through model pre-trained on MNLI
        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = model(input_ids)[0]

        # we throw away "neutral" (dim 1) and take the probability of
        # "entailment" (2) as the probability of the label being true 
        entail_contradiction_logits = logits[:,[0,2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        true_prob = probs[:,1].item() * 100

        # append true probability to list
        true_probs.append(true_prob)

    return true_probs


def get_highest_score(true_probs, hypotheses):
    
    # get index of hypothesis with highest score
    highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])

    # get hypothesis with highest score
    highest_hypothesis = hypotheses[highest_index]

    # get highest probability
    highest_prob = true_probs[highest_index]
    
    return (highest_hypothesis, highest_prob)


def get_sorted_scores(true_probs, hypotheses):

    # sort (score, hypothesis) pairs by score, highest first
    ranked = sorted(zip(true_probs, hypotheses), reverse=True)

    return [(hypothesis, score) for score, hypothesis in ranked]
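
Because `zero_shot_prediction` scores each label in its own forward pass, the returned percentages are independent of one another and need not sum to 100%. When exactly one class applies per text, one alternative is to apply a softmax to the entailment logits across all labels. The sketch below illustrates this on made-up logits; `normalize_across_labels` is a hypothetical helper, not part of this notebook.

```python
import math


def normalize_across_labels(entailment_logits):
    """Softmax over one entailment logit per candidate label."""
    m = max(entailment_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in entailment_logits]
    total = sum(exps)
    return [e / total for e in exps]


# made-up entailment logits for three labels, for illustration only
labels = ["Commerce", "Anatomie", "Politique"]
probs = normalize_across_labels([2.0, 1.0, 0.0])
for label, p in sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True):
    print(f"{label}: {p * 100:0.2f}%")
```

This changes the scale of the scores but not the ranking of labels by entailment logit.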
    

In [11]:
# test
premise = df[column_text].tolist()[0]

true_probs = zero_shot_prediction(premise, classes)
highest_score = get_highest_score(true_probs, classes)

# print the results
print(f'The hypothesis with the highest score is: "{highest_score[0]}" with a probability of {highest_score[1]:0.2f}%')


probs = get_sorted_scores(true_probs, classes)
probs

The hypothesis with the highest score is: "Commerce" with a probability of 70.05%


[('Commerce', 70.05096077919006),
 ('Anatomie', 68.73840689659119),
 ('Politique', 60.71174740791321),
 ('Géographie', 59.156250953674316),
 ('Architecture', 58.74174237251282),
 ('Histoire', 57.459235191345215),
 ('Agriculture - Economie rustique', 53.53081226348877),
 ('Histoire naturelle', 48.459288477897644),
 ('Antiquité', 46.68458700180054),
 ('Beaux-arts', 42.856183648109436),
 ('Mesure', 41.31035804748535),
 ('Jeu', 41.22118949890137),
 ('Droit - Jurisprudence', 41.1332905292511),
 ('Minéralogie', 38.137245178222656),
 ('Spectacle', 37.80339956283569),
 ('Pêche', 37.214648723602295),
 ('Superstition', 36.727988719940186),
 ('Arts et métiers', 36.511969566345215),
 ('Métiers', 36.5054726600647),
 ('Monnaie', 35.89862287044525),
 ('Musique', 32.74966776371002),
 ('Mathématiques', 32.70111680030823),
 ('Chasse', 29.35197949409485),
 ('Economie domestique', 28.346234560012817),
 ('Philosophie', 27.653270959854126),
 ('Chimie', 25.783824920654297),
 ('Physique - [Sciences phy

In [13]:
y_true = df[column_class].tolist()

In [14]:
def get_tsv_content(y_true, prob_labels):
    c = ''
    for i, row in enumerate(prob_labels):
        c += y_true[i] + '\t'
        for t in row:
            c += t[0] + '\t' + str(t[1])+'\t'
        c += '\n'

    return c
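
Each line written by `get_tsv_content` holds the true label followed by alternating label and score fields, and ends with a trailing tab. A minimal sketch of reading such lines back and computing top-1 accuracy follows; `parse_row` and `top1_accuracy` are hypothetical helpers and the sample rows are invented.

```python
def parse_row(line):
    """Split one output line into (true_label, [(label, score), ...])."""
    fields = line.rstrip("\n").split("\t")
    # each row ends with a trailing tab, which leaves one empty field
    if fields and fields[-1] == "":
        fields.pop()
    true_label, rest = fields[0], fields[1:]
    ranked = [(rest[i], float(rest[i + 1])) for i in range(0, len(rest), 2)]
    return true_label, ranked


def top1_accuracy(lines):
    """Fraction of rows whose highest-ranked label equals the true label."""
    rows = [parse_row(line) for line in lines if line.strip()]
    hits = sum(1 for true_label, ranked in rows if ranked[0][0] == true_label)
    return hits / len(rows)


# two made-up rows in the same layout, for illustration only
sample = [
    "Commerce\tCommerce\t70.05\tAnatomie\t68.74\t\n",
    "Marine\tCommerce\t55.10\tMarine\t54.20\t\n",
]
print(top1_accuracy(sample))
```

The parser relies on the labels being stored already sorted by score, which is what `get_sorted_scores` produces.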

In [15]:
texts = df[column_text].tolist()
batch_size = 20

for i in tqdm(range(0, len(texts), batch_size)):
    batch = texts[i:i+batch_size]
    batch_y_true = y_true[i:i+batch_size]

    prob_labels = []

    for content in batch:

        # keep only the first 512 characters (not tokens) of each article
        true_probs = zero_shot_prediction(content[:512], classes)
        
        #pred_labels.append(get_highest_score(true_probs, classes)[0])
        prob_labels.append(get_sorted_scores(true_probs, classes))

    with open('zero-shot-classification.tsv', 'a') as f:
        f.write(get_tsv_content(batch_y_true, prob_labels))
    #print(prob_labels) 


100%|██████████| 673/673 [54:36:49<00:00, 292.14s/it]
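
The progress bar can be cross-checked against the sizes reported earlier in the notebook (13441 rows after `dropna`, 38 classes, batches of 20). The arithmetic below is only a sanity check; the per-call figure is an average that lumps together tokenization, inference and file writes.

```python
import math

n_texts = 13441      # rows left after dropna (df.shape above)
n_labels = 38        # number of entries in `classes`
batch_size = 20

n_batches = math.ceil(n_texts / batch_size)
forward_passes = n_texts * n_labels

# total wall-clock time reported by tqdm: 54:36:49
total_seconds = 54 * 3600 + 36 * 60 + 49

print(n_batches)                                  # number of tqdm iterations
print(forward_passes)                             # one model call per (text, label) pair
print(round(total_seconds / n_batches, 2))        # seconds per batch
print(round(total_seconds / forward_passes, 3))   # seconds per model call
```

The run time grows linearly with the number of candidate labels, since every text is paired with every class in a separate forward pass.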
