From 769ee67ad7b4c137b96ab639c86978af889657e3 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Sat, 14 Jan 2023 17:38:25 +0100
Subject: [PATCH] Update Classification_Zero-Shot-Learning.ipynb
---
.../Classification_Zero-Shot-Learning.ipynb | 473 +++++++++++++++++-
1 file changed, 457 insertions(+), 16 deletions(-)
diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb
index 78e4214..2869c84 100644
--- a/notebooks/Classification_Zero-Shot-Learning.ipynb
+++ b/notebooks/Classification_Zero-Shot-Learning.ipynb
@@ -80,7 +80,7 @@
"metadata": {},
"outputs": [],
"source": [
- "path = \"drive/MyDrive/Classification-EDdA/\""
+ "output_path = \"drive/MyDrive/Classification-EDdA/\""
]
},
{
@@ -95,7 +95,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
"id": "bcptSr6o3ac7"
},
@@ -143,19 +143,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"dataset_path = 'EDdA_dataframe_withContent.tsv'\n",
"training_set_path = 'training_set.tsv'\n",
"test_set_path = 'test_set.tsv'\n",
- "\n"
+ "\n",
+ "input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n",
+ "#input_path = ''  # alternative: read the TSV files from the notebook's working directory\n",
+ "output_path = ''"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -163,26 +166,296 @@
"id": "LRKJzWmf3pCg",
"outputId": "686c3ef4-8267-4266-95af-7193725aadca"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>volume</th>\n",
+ " <th>numero</th>\n",
+ " <th>head</th>\n",
+ " <th>normClass</th>\n",
+ " <th>classEDdA</th>\n",
+ " <th>author</th>\n",
+ " <th>id_enccre</th>\n",
+ " <th>domaine_enccre</th>\n",
+ " <th>ensemble_domaine_enccre</th>\n",
+ " <th>content</th>\n",
+ " <th>contentWithoutClass</th>\n",
+ " <th>firstParagraph</th>\n",
+ " <th>nb_word</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>11</td>\n",
+ " <td>2973</td>\n",
+ " <td>ORNIS</td>\n",
+ " <td>Commerce</td>\n",
+ " <td>Comm.</td>\n",
+ " <td>unsigned</td>\n",
+ " <td>v11-1767-0</td>\n",
+ " <td>commerce</td>\n",
+ " <td>Commerce</td>\n",
+ " <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n",
+ " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
+ " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
+ " <td>45</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>3</td>\n",
+ " <td>3525</td>\n",
+ " <td>COMPRENDRE</td>\n",
+ " <td>Philosophie</td>\n",
+ " <td>terme de Philosophie,</td>\n",
+ " <td>Diderot</td>\n",
+ " <td>v3-1722-0</td>\n",
+ " <td>NaN</td>\n",
+ " <td>NaN</td>\n",
+ " <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n",
+ " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
+ " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
+ " <td>92</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>1</td>\n",
+ " <td>2560</td>\n",
+ " <td>ANCRE</td>\n",
+ " <td>Marine</td>\n",
+ " <td>Marine</td>\n",
+ " <td>d'Alembert & Diderot</td>\n",
+ " <td>v1-1865-0</td>\n",
+ " <td>marine</td>\n",
+ " <td>Marine</td>\n",
+ " <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n",
+ " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
+ " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
+ " <td>3327</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>16</td>\n",
+ " <td>4241</td>\n",
+ " <td>VAKEBARO</td>\n",
+ " <td>Géographie moderne</td>\n",
+ " <td>Géog. mod.</td>\n",
+ " <td>unsigned</td>\n",
+ " <td>v16-2587-0</td>\n",
+ " <td>géographie</td>\n",
+ " <td>Géographie</td>\n",
+ " <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n",
+ " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
+ " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
+ " <td>34</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>8</td>\n",
+ " <td>3281</td>\n",
+ " <td>INSPECTEUR</td>\n",
+ " <td>Histoire ancienne</td>\n",
+ " <td>Hist. anc.</td>\n",
+ " <td>unsigned</td>\n",
+ " <td>v8-2533-0</td>\n",
+ " <td>histoire</td>\n",
+ " <td>Histoire</td>\n",
+ " <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n",
+ " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
+ " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
+ " <td>102</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " volume numero head normClass classEDdA \\\n",
+ "0 11 2973 ORNIS Commerce Comm. \n",
+ "1 3 3525 COMPRENDRE Philosophie terme de Philosophie, \n",
+ "2 1 2560 ANCRE Marine Marine \n",
+ "3 16 4241 VAKEBARO Géographie moderne Géog. mod. \n",
+ "4 8 3281 INSPECTEUR Histoire ancienne Hist. anc. \n",
+ "\n",
+ " author id_enccre domaine_enccre ensemble_domaine_enccre \\\n",
+ "0 unsigned v11-1767-0 commerce Commerce \n",
+ "1 Diderot v3-1722-0 NaN NaN \n",
+ "2 d'Alembert & Diderot v1-1865-0 marine Marine \n",
+ "3 unsigned v16-2587-0 géographie Géographie \n",
+ "4 unsigned v8-2533-0 histoire Histoire \n",
+ "\n",
+ " content \\\n",
+ "0 ORNIS, s. m. toile des Indes, (Comm.) sortes d... \n",
+ "1 * COMPRENDRE, v. act. terme de Philosophie,\\nc... \n",
+ "2 ANCRE, s. f. (Marine.) est un instrument de fe... \n",
+ "3 VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es... \n",
+ "4 INSPECTEUR, s. m. inspector ; (Hist. anc.) cel... \n",
+ "\n",
+ " contentWithoutClass \\\n",
+ "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... \n",
+ "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... \n",
+ "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... \n",
+ "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... \n",
+ "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... \n",
+ "\n",
+ " firstParagraph nb_word \n",
+ "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... 45 \n",
+ "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... 92 \n",
+ "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... 3327 \n",
+ "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... 34 \n",
+ "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... 102 "
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "df = pd.read_csv(test_set_path, sep=\"\\t\")\n",
- "\n",
+ "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
"df.head()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(15854, 13)"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
- "column_text = 'contentWithoutClass'\n",
+ "#column_text = 'contentWithoutClass'  # article text with the class mention (e.g. 'Comm.') stripped out\n",
+ "column_text = 'content'\n",
"column_class = 'ensemble_domaine_enccre'"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(13441, 13)"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Commerce',\n",
+ " 'Marine',\n",
+ " 'Géographie',\n",
+ " 'Histoire',\n",
+ " 'Belles-lettres - Poésie',\n",
+ " 'Economie domestique',\n",
+ " 'Droit - Jurisprudence',\n",
+ " 'Médecine - Chirurgie',\n",
+ " 'Militaire (Art) - Guerre - Arme',\n",
+ " 'Beaux-arts',\n",
+ " 'Antiquité',\n",
+ " 'Histoire naturelle',\n",
+ " 'Grammaire',\n",
+ " 'Philosophie',\n",
+ " 'Arts et métiers',\n",
+ " 'Pharmacie',\n",
+ " 'Religion',\n",
+ " 'Pêche',\n",
+ " 'Anatomie',\n",
+ " 'Architecture',\n",
+ " 'Musique',\n",
+ " 'Jeu',\n",
+ " 'Caractères',\n",
+ " 'Métiers',\n",
+ " 'Physique - [Sciences physico-mathématiques]',\n",
+ " 'Maréchage - Manège',\n",
+ " 'Chimie',\n",
+ " 'Blason',\n",
+ " 'Chasse',\n",
+ " 'Mathématiques',\n",
+ " 'Médailles',\n",
+ " 'Superstition',\n",
+ " 'Agriculture - Economie rustique',\n",
+ " 'Mesure',\n",
+ " 'Monnaie',\n",
+ " 'Minéralogie',\n",
+ " 'Politique',\n",
+ " 'Spectacle']"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "classes = df[column_class].unique().tolist()\n",
+ "classes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -191,7 +464,18 @@
"id": "RbsHOiJdNYRL",
"outputId": "bbdafc35-cf09-4a20-c3c0-901b8adce561"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\""
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"df[column_text].tolist()[0]"
]
@@ -214,9 +498,80 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 30,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e5a45c55993f47019fbdc0aceda84def",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/899k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a7285f3fe7154920a0bb05fdb921d6f9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/456k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0cb38fe0e7934c49bbce246e88dd6e53",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/26.0 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2c9dd205446a4b25848f1f06beeac8ae",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/1.15k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "032334b5310d4588993b7d85329d916c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/1.63G [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"# load model pretrained on MNLI\n",
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
@@ -229,6 +584,13 @@
"metadata": {},
"outputs": [],
"source": [
+ "''' \n",
+ "## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
+ "\n",
+ "# load model pretrained on MNLI\n",
+ "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
+ "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
+ "\n",
"# pose sequence as a NLI premise and label (politics) as a hypothesis\n",
"premise = 'Who are you voting for in 2020?'\n",
"hypothesis = 'This text is about politics.'\n",
@@ -242,8 +604,79 @@
"entail_contradiction_logits = logits[:,[0,2]]\n",
"probs = entail_contradiction_logits.softmax(dim=1)\n",
"true_prob = probs[:,1].item() * 100\n",
- "print(f'Probability that the label is true: {true_prob:0.2f}%')"
+ "print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
+ "'''"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n"
+ ]
+ }
+ ],
+ "source": [
+ "# pose the article text as an NLI premise and each candidate class label as a hypothesis\n",
+ "premise = df[column_text].tolist()[0]\n",
+ "#hypothesis = 'This text is about politics.'\n",
+ "hypotheses = classes\n",
+ "\n",
+ "# list to store the true probability of each hypothesis\n",
+ "true_probs = []\n",
+ "\n",
+ "# loop through hypotheses\n",
+ "for hypothesis in hypotheses:\n",
+ "\n",
+ " # run through model pre-trained on MNLI\n",
+ " input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
+ " logits = model(input_ids)[0]\n",
+ "\n",
+ " # drop the \"neutral\" logit (index 1) and softmax over \"contradiction\" (0) and\n",
+ " # \"entailment\" (2); the entailment share is the probability of the label being true\n",
+ " entail_contradiction_logits = logits[:,[0,2]]\n",
+ " probs = entail_contradiction_logits.softmax(dim=1)\n",
+ " true_prob = probs[:,1].item() * 100\n",
+ "\n",
+ " # append true probability to list\n",
+ " true_probs.append(true_prob)\n",
+ "\n",
+ "# print the true probability for each hypothesis\n",
+ "#for i, hypothesis in enumerate(hypotheses):\n",
+ "# print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n",
+ "# print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
+ "\n",
+ "# get index of hypothesis with highest score\n",
+ "highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n",
+ "\n",
+ "# get hypothesis with highest score\n",
+ "highest_hypothesis = hypotheses[highest_index]\n",
+ "\n",
+ "# get highest probability\n",
+ "highest_prob = true_probs[highest_index]\n",
+ "\n",
+ "# print the results\n",
+ "print(f'The hypothesis with the highest score is: \"{highest_hypothesis}\" with a probability of {highest_prob:0.2f}%')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -259,8 +692,16 @@
"name": "python3"
},
"language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
"name": "python",
- "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]"
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
},
"vscode": {
"interpreter": {
--
GitLab