From 9cfce11e78a38c6a39761ef3dfb04ff7012936ba Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Mon, 16 Jan 2023 15:02:16 +0100
Subject: [PATCH] Update Classification_Zero-Shot-Learning.ipynb

---
 .../Classification_Zero-Shot-Learning.ipynb   | 469 ++----------------
 1 file changed, 29 insertions(+), 440 deletions(-)

diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb
index ad73e6a..2165378 100644
--- a/notebooks/Classification_Zero-Shot-Learning.ipynb
+++ b/notebooks/Classification_Zero-Shot-Learning.ipynb
@@ -95,7 +95,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 1,
+      "execution_count": null,
       "metadata": {
         "id": "bcptSr6o3ac7"
       },
@@ -103,7 +103,7 @@
       "source": [
         "import pandas as pd\n",
         "from tqdm import tqdm\n",
-        "from transformers import BartForSequenceClassification, BartTokenizer\n"
+        "from transformers import BartForSequenceClassification, BartTokenizer"
       ]
     },
     {
@@ -143,7 +143,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 2,
+      "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -158,7 +158,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -166,170 +166,7 @@
         "id": "LRKJzWmf3pCg",
         "outputId": "686c3ef4-8267-4266-95af-7193725aadca"
       },
-      "outputs": [
-        {
-          "data": {
-            "text/html": [
-              "<div>\n",
-              "<style scoped>\n",
-              "    .dataframe tbody tr th:only-of-type {\n",
-              "        vertical-align: middle;\n",
-              "    }\n",
-              "\n",
-              "    .dataframe tbody tr th {\n",
-              "        vertical-align: top;\n",
-              "    }\n",
-              "\n",
-              "    .dataframe thead th {\n",
-              "        text-align: right;\n",
-              "    }\n",
-              "</style>\n",
-              "<table border=\"1\" class=\"dataframe\">\n",
-              "  <thead>\n",
-              "    <tr style=\"text-align: right;\">\n",
-              "      <th></th>\n",
-              "      <th>volume</th>\n",
-              "      <th>numero</th>\n",
-              "      <th>head</th>\n",
-              "      <th>normClass</th>\n",
-              "      <th>classEDdA</th>\n",
-              "      <th>author</th>\n",
-              "      <th>id_enccre</th>\n",
-              "      <th>domaine_enccre</th>\n",
-              "      <th>ensemble_domaine_enccre</th>\n",
-              "      <th>content</th>\n",
-              "      <th>contentWithoutClass</th>\n",
-              "      <th>firstParagraph</th>\n",
-              "      <th>nb_word</th>\n",
-              "    </tr>\n",
-              "  </thead>\n",
-              "  <tbody>\n",
-              "    <tr>\n",
-              "      <th>0</th>\n",
-              "      <td>11</td>\n",
-              "      <td>2973</td>\n",
-              "      <td>ORNIS</td>\n",
-              "      <td>Commerce</td>\n",
-              "      <td>Comm.</td>\n",
-              "      <td>unsigned</td>\n",
-              "      <td>v11-1767-0</td>\n",
-              "      <td>commerce</td>\n",
-              "      <td>Commerce</td>\n",
-              "      <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n",
-              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
-              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
-              "      <td>45</td>\n",
-              "    </tr>\n",
-              "    <tr>\n",
-              "      <th>1</th>\n",
-              "      <td>3</td>\n",
-              "      <td>3525</td>\n",
-              "      <td>COMPRENDRE</td>\n",
-              "      <td>Philosophie</td>\n",
-              "      <td>terme de Philosophie,</td>\n",
-              "      <td>Diderot</td>\n",
-              "      <td>v3-1722-0</td>\n",
-              "      <td>NaN</td>\n",
-              "      <td>NaN</td>\n",
-              "      <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n",
-              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
-              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
-              "      <td>92</td>\n",
-              "    </tr>\n",
-              "    <tr>\n",
-              "      <th>2</th>\n",
-              "      <td>1</td>\n",
-              "      <td>2560</td>\n",
-              "      <td>ANCRE</td>\n",
-              "      <td>Marine</td>\n",
-              "      <td>Marine</td>\n",
-              "      <td>d'Alembert &amp; Diderot</td>\n",
-              "      <td>v1-1865-0</td>\n",
-              "      <td>marine</td>\n",
-              "      <td>Marine</td>\n",
-              "      <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n",
-              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
-              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
-              "      <td>3327</td>\n",
-              "    </tr>\n",
-              "    <tr>\n",
-              "      <th>3</th>\n",
-              "      <td>16</td>\n",
-              "      <td>4241</td>\n",
-              "      <td>VAKEBARO</td>\n",
-              "      <td>Géographie moderne</td>\n",
-              "      <td>Géog. mod.</td>\n",
-              "      <td>unsigned</td>\n",
-              "      <td>v16-2587-0</td>\n",
-              "      <td>géographie</td>\n",
-              "      <td>Géographie</td>\n",
-              "      <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n",
-              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
-              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
-              "      <td>34</td>\n",
-              "    </tr>\n",
-              "    <tr>\n",
-              "      <th>4</th>\n",
-              "      <td>8</td>\n",
-              "      <td>3281</td>\n",
-              "      <td>INSPECTEUR</td>\n",
-              "      <td>Histoire ancienne</td>\n",
-              "      <td>Hist. anc.</td>\n",
-              "      <td>unsigned</td>\n",
-              "      <td>v8-2533-0</td>\n",
-              "      <td>histoire</td>\n",
-              "      <td>Histoire</td>\n",
-              "      <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n",
-              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
-              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
-              "      <td>102</td>\n",
-              "    </tr>\n",
-              "  </tbody>\n",
-              "</table>\n",
-              "</div>"
-            ],
-            "text/plain": [
-              "   volume  numero        head           normClass              classEDdA  \\\n",
-              "0      11    2973       ORNIS            Commerce                  Comm.   \n",
-              "1       3    3525  COMPRENDRE         Philosophie  terme de Philosophie,   \n",
-              "2       1    2560       ANCRE              Marine                 Marine   \n",
-              "3      16    4241    VAKEBARO  Géographie moderne             Géog. mod.   \n",
-              "4       8    3281  INSPECTEUR   Histoire ancienne             Hist. anc.   \n",
-              "\n",
-              "                 author   id_enccre domaine_enccre ensemble_domaine_enccre  \\\n",
-              "0              unsigned  v11-1767-0       commerce                Commerce   \n",
-              "1               Diderot   v3-1722-0            NaN                     NaN   \n",
-              "2  d'Alembert & Diderot   v1-1865-0         marine                  Marine   \n",
-              "3              unsigned  v16-2587-0     géographie              Géographie   \n",
-              "4              unsigned   v8-2533-0       histoire                Histoire   \n",
-              "\n",
-              "                                             content  \\\n",
-              "0  ORNIS, s. m. toile des Indes, (Comm.) sortes d...   \n",
-              "1  * COMPRENDRE, v. act. terme de Philosophie,\\nc...   \n",
-              "2  ANCRE, s. f. (Marine.) est un instrument de fe...   \n",
-              "3  VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...   \n",
-              "4  INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...   \n",
-              "\n",
-              "                                 contentWithoutClass  \\\n",
-              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...   \n",
-              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...   \n",
-              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...   \n",
-              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...   \n",
-              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...   \n",
-              "\n",
-              "                                      firstParagraph  nb_word  \n",
-              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...       45  \n",
-              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...       92  \n",
-              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...     3327  \n",
-              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...       34  \n",
-              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...      102  "
-            ]
-          },
-          "execution_count": 3,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
         "df.head()"
@@ -337,27 +174,16 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "(15854, 13)"
-            ]
-          },
-          "execution_count": 4,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "df.shape"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
+      "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -368,7 +194,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 6,
+      "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -377,77 +203,18 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "(13441, 13)"
-            ]
-          },
-          "execution_count": 7,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "df.shape"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "['Commerce',\n",
-              " 'Marine',\n",
-              " 'Géographie',\n",
-              " 'Histoire',\n",
-              " 'Belles-lettres - Poésie',\n",
-              " 'Economie domestique',\n",
-              " 'Droit - Jurisprudence',\n",
-              " 'Médecine - Chirurgie',\n",
-              " 'Militaire (Art) - Guerre - Arme',\n",
-              " 'Beaux-arts',\n",
-              " 'Antiquité',\n",
-              " 'Histoire naturelle',\n",
-              " 'Grammaire',\n",
-              " 'Philosophie',\n",
-              " 'Arts et métiers',\n",
-              " 'Pharmacie',\n",
-              " 'Religion',\n",
-              " 'Pêche',\n",
-              " 'Anatomie',\n",
-              " 'Architecture',\n",
-              " 'Musique',\n",
-              " 'Jeu',\n",
-              " 'Caractères',\n",
-              " 'Métiers',\n",
-              " 'Physique - [Sciences physico-mathématiques]',\n",
-              " 'Maréchage - Manège',\n",
-              " 'Chimie',\n",
-              " 'Blason',\n",
-              " 'Chasse',\n",
-              " 'Mathématiques',\n",
-              " 'Médailles',\n",
-              " 'Superstition',\n",
-              " 'Agriculture - Economie rustique',\n",
-              " 'Mesure',\n",
-              " 'Monnaie',\n",
-              " 'Minéralogie',\n",
-              " 'Politique',\n",
-              " 'Spectacle']"
-            ]
-          },
-          "execution_count": 8,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "classes = df[column_class].unique().tolist()\n",
         "classes"
@@ -471,80 +238,9 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "cf58eccecc1847e48d520a83040e3ec7",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        },
-        {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "2e91d1ab181f409f9c6263a255991c9f",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        },
-        {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "4c23b1cb783d4c25b8332447bf25755a",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        },
-        {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "203385bbd7664b8b87250e4883cc300f",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        },
-        {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "c068998923a14abf8e38d8e0d89248ad",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        }
-      ],
+      "outputs": [],
       "source": [
         "# load model pretrained on MNLI\n",
         "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
@@ -583,7 +279,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -638,84 +334,18 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\""
-            ]
-          },
-          "execution_count": 11,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "df[column_text].tolist()[0]"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n"
-          ]
-        },
-        {
-          "data": {
-            "text/plain": [
-              "[('Commerce', 70.05096077919006),\n",
-              " ('Anatomie', 68.73840689659119),\n",
-              " ('Politique', 60.71174740791321),\n",
-              " ('Géographie', 59.156250953674316),\n",
-              " ('Architecture', 58.74174237251282),\n",
-              " ('Histoire', 57.459235191345215),\n",
-              " ('Agriculture - Economie rustique', 53.53081226348877),\n",
-              " ('Histoire naturelle', 48.459288477897644),\n",
-              " ('Antiquité', 46.68458700180054),\n",
-              " ('Beaux-arts', 42.856183648109436),\n",
-              " ('Mesure', 41.31035804748535),\n",
-              " ('Jeu', 41.22118949890137),\n",
-              " ('Droit - Jurisprudence', 41.1332905292511),\n",
-              " ('Minéralogie', 38.137245178222656),\n",
-              " ('Spectacle', 37.80339956283569),\n",
-              " ('Pêche', 37.214648723602295),\n",
-              " ('Superstition', 36.727988719940186),\n",
-              " ('Arts et métiers', 36.511969566345215),\n",
-              " ('Métiers', 36.5054726600647),\n",
-              " ('Monnaie', 35.89862287044525),\n",
-              " ('Musique', 32.74966776371002),\n",
-              " ('Mathématiques', 32.70111680030823),\n",
-              " ('Chasse', 29.35197949409485),\n",
-              " ('Economie domestique', 28.346234560012817),\n",
-              " ('Philosophie', 27.653270959854126),\n",
-              " ('Chimie', 25.783824920654297),\n",
-              " ('Physique - [Sciences physico-mathématiques]', 25.4037082195282),\n",
-              " ('Médailles', 24.58679974079132),\n",
-              " ('Grammaire', 22.36253321170807),\n",
-              " ('Caractères', 20.14845609664917),\n",
-              " ('Pharmacie', 19.720394909381866),\n",
-              " ('Militaire (Art) - Guerre - Arme', 19.682711362838745),\n",
-              " ('Médecine - Chirurgie', 18.615825474262238),\n",
-              " ('Marine', 18.208028376102448),\n",
-              " ('Belles-lettres - Poésie', 13.306896388530731),\n",
-              " ('Blason', 10.476677119731903),\n",
-              " ('Religion', 9.702161699533463),\n",
-              " ('Maréchage - Manège', 4.211411997675896)]"
-            ]
-          },
-          "execution_count": 12,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "premise = df[column_text].tolist()[0]\n",
         "\n",
@@ -732,20 +362,9 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "'Commerce'"
-            ]
-          },
-          "execution_count": 13,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "probs[0][0]"
       ]
@@ -778,48 +397,18 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 35,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "data": {
-            "text/plain": [
-              "672.05"
-            ]
-          },
-          "execution_count": 35,
-          "metadata": {},
-          "output_type": "execute_result"
-        }
-      ],
+      "outputs": [],
       "source": [
         "len(texts) / 20"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 36,
+      "execution_count": null,
       "metadata": {},
-      "outputs": [
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "  1%|          | 5/673 [10:28<23:19:26, 125.70s/it]\n"
-          ]
-        },
-        {
-          "ename": "IndexError",
-          "evalue": "string index out of range",
-          "output_type": "error",
-          "traceback": [
-            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-            "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
-            "Cell \u001b[0;32mIn[36], line 12\u001b[0m\n\u001b[1;32m      8\u001b[0m prob_labels \u001b[39m=\u001b[39m []\n\u001b[1;32m     10\u001b[0m \u001b[39mfor\u001b[39;00m j, content \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(batch):\n\u001b[0;32m---> 12\u001b[0m     true_probs \u001b[39m=\u001b[39m zero_shot_prediction(content[i][:\u001b[39m1024\u001b[39m], classes)\n\u001b[1;32m     14\u001b[0m     \u001b[39m#pred_labels.append(get_highest_score(true_probs, classes)[0])\u001b[39;00m\n\u001b[1;32m     15\u001b[0m     prob_labels\u001b[39m.\u001b[39mappend(get_sorted_scores(true_probs, classes))\n",
-            "\u001b[0;31mIndexError\u001b[0m: string index out of range"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "texts = df[column_text].tolist()\n",
         "batch_size = 20\n",
@@ -830,9 +419,9 @@
         "\n",
         "    prob_labels = []\n",
         "\n",
-        "    for j, content in enumerate(batch):\n",
+        "    for content in batch:\n",
         "\n",
-        "        true_probs = zero_shot_prediction(content[j][:1024], classes)\n",
+        "        true_probs = zero_shot_prediction(content[:512], classes)\n",
         "        \n",
         "        #pred_labels.append(get_highest_score(true_probs, classes)[0])\n",
         "        prob_labels.append(get_sorted_scores(true_probs, classes))\n",
@@ -844,7 +433,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 22,
+      "execution_count": null,
       "metadata": {},
       "outputs": [],
       "source": []
-- 
GitLab