{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "aXLlx8vXQlJw"
},
"source": [
"# Zero Shot Topic Classification with Transformers\n",
"\n",
"https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
"\n",
"https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd\n",
"\n",
"https://huggingface.co/spaces/joeddav/zero-shot-demo"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "3kYI_pq3Q1BT"
},
"source": [
"## 1. Configuration"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "P_L0rDhZQ6Fn"
},
"source": [
"### 1.1 Setup colab environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FsAR4CsB3aUc",
"outputId": "e0791012-6858-4ee0-f724-7f33c6985ee8"
},
"outputs": [],
"source": [
"from psutil import virtual_memory\n",
"ram_gb = virtual_memory().total / 1e9\n",
"print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n",
"\n",
"if ram_gb < 20:\n",
" print('Not using a high-RAM runtime')\n",
"else:\n",
" print('You are using a high-RAM runtime!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h5MwRwL53aYY",
"outputId": "20a93907-e5df-47b1-9172-d1693ef76dc5"
},
"outputs": [],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_path = \"drive/MyDrive/Classification-EDdA/\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "4z78CLYi75kV"
},
"source": [
"### 1.2 Import libraries"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bcptSr6o3ac7"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from transformers import BartForSequenceClassification, BartTokenizer\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Lc1DRh4b7mto"
},
"source": [
"## 2. Load datasets"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2.1 Download datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ybiJYL0h3ahh",
"outputId": "0638f9a2-f9a0-4d96-9760-991ddc5747ca"
},
"outputs": [],
"source": [
"!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv\n",
"!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv\n",
"!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_path = 'EDdA_dataframe_withContent.tsv'\n",
"training_set_path = 'training_set.tsv'\n",
"test_set_path = 'test_set.tsv'\n",
"\n",
"input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n",
"#input_path = ''\n",
"output_path = ''"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LRKJzWmf3pCg",
"outputId": "686c3ef4-8267-4266-95af-7193725aadca"
},
"outputs": [],
"source": [
"df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#column_text = 'contentWithoutClass'\n",
"column_text = 'content'\n",
"column_class = 'ensemble_domaine_enccre'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"classes = df[column_class].unique().tolist()\n",
"classes"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Classification"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the \"premise\" and to turn each candidate label into a \"hypothesis.\" If the NLI model predicts that the premise \"entails\" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers."
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# load model pretrained on MNLI\n",
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
"model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')"
]
},
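{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Loading the BART-large checkpoint on CPU works, but scoring every (article, label) pair is slow. If the Colab runtime provides a GPU, the model can be moved over first; a minimal sketch (note that the encoded `input_ids` would then also need a matching `.to(device)` before each forward pass):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# pick the GPU when available, otherwise fall back to CPU\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = model.to(device)\n",
"model.eval()  # disable dropout for deterministic inference"
]
},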
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"''' \n",
"## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
"\n",
"# load model pretrained on MNLI\n",
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
"model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
"\n",
"# pose sequence as a NLI premise and label (politics) as a hypothesis\n",
"premise = 'Who are you voting for in 2020?'\n",
"hypothesis = 'This text is about politics.'\n",
"\n",
"# run through model pre-trained on MNLI\n",
"input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
"logits = model(input_ids)[0]\n",
"\n",
"# we throw away \"neutral\" (dim 1) and take the probability of\n",
"# \"entailment\" (2) as the probability of the label being true \n",
"entail_contradiction_logits = logits[:,[0,2]]\n",
"probs = entail_contradiction_logits.softmax(dim=1)\n",
"true_prob = probs[:,1].item() * 100\n",
"print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
"'''"
},
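{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The same premise/hypothesis trick is also wrapped by the 🤗 Transformers `zero-shot-classification` pipeline. A minimal sketch of the equivalent call (it should rank the labels like the manual loop below, assuming the installed `transformers` version ships this pipeline):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: the pipeline wraps the NLI entailment trick implemented manually below\n",
"from transformers import pipeline\n",
"\n",
"classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')\n",
"\n",
"# score the first test article against all domains (same character-level truncation as below)\n",
"result = classifier(df[column_text].tolist()[0][:1024], candidate_labels=classes)\n",
"print(result['labels'][0], result['scores'][0])"
]
},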
{
"cell_type": "code",
"execution_count": null,
"def zero_shot_prediction(premise, hypotheses):\n",
" # list to store the true probability of each hypothesis\n",
" true_probs = []\n",
" # loop through hypotheses\n",
" for hypothesis in hypotheses:\n",
" # run through model pre-trained on MNLI\n",
" input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
" logits = model(input_ids)[0]\n",
" # we throw away \"neutral\" (dim 1) and take the probability of\n",
" # \"entailment\" (2) as the probability of the label being true \n",
" entail_contradiction_logits = logits[:,[0,2]]\n",
" probs = entail_contradiction_logits.softmax(dim=1)\n",
" true_prob = probs[:,1].item() * 100\n",
" # append true probability to list\n",
" true_probs.append(true_prob)\n",
"def get_highest_score(true_probs, hypotheses):\n",
" # get index of hypothesis with highest score\n",
" highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n",
" # get hypothesis with highest score\n",
" highest_hypothesis = hypotheses[highest_index]\n",
" # get highest probability\n",
" highest_prob = true_probs[highest_index]\n",
" \n",
" return (highest_hypothesis, highest_prob)\n",
"\n",
"\n",
"def get_sorted_scores(true_probs, hypotheses):\n",
"\n",
" # sort hypotheses based on their scores\n",
" sorted_hypotheses = [hypothesis for _, hypothesis in sorted(zip(true_probs, hypotheses), reverse=True)]\n",
"\n",
" # sort scores\n",
" sorted_scores = sorted(true_probs, reverse=True)\n",
" \n",
" return list(zip(sorted_hypotheses, sorted_scores))\n",
]
},
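{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"No gradients are needed at inference time, so the scoring loop can be wrapped in `torch.no_grad()` to reduce memory use and speed up the pass over the corpus. A minimal sketch of an equivalent helper (the `truncation=True` argument is an addition here, guarding against articles longer than BART's 1024-token limit):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def zero_shot_prediction_no_grad(premise, hypotheses):\n",
"    # same scoring loop as zero_shot_prediction, without building autograd graphs\n",
"    true_probs = []\n",
"    with torch.no_grad():\n",
"        for hypothesis in hypotheses:\n",
"            input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation=True)\n",
"            logits = model(input_ids)[0]\n",
"            # keep contradiction (0) and entailment (2); drop neutral (1)\n",
"            entail_contradiction_logits = logits[:, [0, 2]]\n",
"            probs = entail_contradiction_logits.softmax(dim=1)\n",
"            true_probs.append(probs[:, 1].item() * 100)\n",
"    return true_probs"
]
},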
{
"cell_type": "code",
"execution_count": null,
"source": [
"df[column_text].tolist()[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"premise = df[column_text].tolist()[0]\n",
"\n",
"true_probs = zero_shot_prediction(premise, classes)\n",
"highest_score = get_highest_score(true_probs, classes)\n",
"\n",
"# print the results\n",
"print(f'The hypothesis with the highest score is: \"{highest_score[0]}\" with a probability of {highest_score[1]:0.2f}%')\n",
"\n",
"\n",
"probs = get_sorted_scores(true_probs, classes)\n",
"probs\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"probs[0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_labels = []\n",
"prob_labels = []\n",
"\n",
"for content in tqdm(df[column_text].tolist()):\n",
"\n",
" true_probs = zero_shot_prediction(content[:1024], classes)\n",
" \n",
" pred_labels.append(get_highest_score(true_probs, classes)[0])\n",
" prob_labels.append(get_sorted_scores(true_probs, classes))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [
"true_labels = df[column_class].tolist()"
},
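{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With gold labels and zero-shot predictions aligned, standard metrics give a first picture of how well the MNLI model transfers to the EDdA domains. A minimal evaluation sketch (assumes scikit-learn is available in the runtime):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"# compare zero-shot predictions against the gold ENCCRE domain labels\n",
"print('Accuracy: {:.3f}'.format(accuracy_score(true_labels, pred_labels)))\n",
"print(classification_report(true_labels, pred_labels, zero_division=0))"
]
},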
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "EDdA-Classification_Clustering.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "geode-classification-py39",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"