From 9cfce11e78a38c6a39761ef3dfb04ff7012936ba Mon Sep 17 00:00:00 2001 From: Ludovic Moncla <moncla.ludovic@gmail.com> Date: Mon, 16 Jan 2023 15:02:16 +0100 Subject: [PATCH] Update Classification_Zero-Shot-Learning.ipynb --- .../Classification_Zero-Shot-Learning.ipynb | 469 ++---------------- 1 file changed, 29 insertions(+), 440 deletions(-) diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb index ad73e6a..2165378 100644 --- a/notebooks/Classification_Zero-Shot-Learning.ipynb +++ b/notebooks/Classification_Zero-Shot-Learning.ipynb @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "bcptSr6o3ac7" }, @@ -103,7 +103,7 @@ "source": [ "import pandas as pd\n", "from tqdm import tqdm\n", - "from transformers import BartForSequenceClassification, BartTokenizer\n" + "from transformers import BartForSequenceClassification, BartTokenizer" ] }, { @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -158,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -166,170 +166,7 @@ "id": "LRKJzWmf3pCg", "outputId": "686c3ef4-8267-4266-95af-7193725aadca" }, - "outputs": [ - { - "data": { - "text/html": [ - "<div>\n", - "<style scoped>\n", - " .dataframe tbody tr th:only-of-type {\n", - " vertical-align: middle;\n", - " }\n", - "\n", - " .dataframe tbody tr th {\n", - " vertical-align: top;\n", - " }\n", - "\n", - " .dataframe thead th {\n", - " text-align: right;\n", - " }\n", - "</style>\n", - "<table border=\"1\" class=\"dataframe\">\n", - " <thead>\n", - " <tr style=\"text-align: right;\">\n", - " <th></th>\n", - " <th>volume</th>\n", - " <th>numero</th>\n", - " <th>head</th>\n", - " <th>normClass</th>\n", - " <th>classEDdA</th>\n", - " <th>author</th>\n", - " <th>id_enccre</th>\n", - " <th>domaine_enccre</th>\n", - " <th>ensemble_domaine_enccre</th>\n", - " <th>content</th>\n", - " <th>contentWithoutClass</th>\n", - " <th>firstParagraph</th>\n", - " <th>nb_word</th>\n", - " </tr>\n", - " </thead>\n", - " <tbody>\n", - " <tr>\n", - " <th>0</th>\n", - " <td>11</td>\n", - " <td>2973</td>\n", - " <td>ORNIS</td>\n", - " <td>Commerce</td>\n", - " <td>Comm.</td>\n", - " <td>unsigned</td>\n", - " <td>v11-1767-0</td>\n", - " <td>commerce</td>\n", - " <td>Commerce</td>\n", - " <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n", - " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n", - " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n", - " <td>45</td>\n", - " </tr>\n", - " <tr>\n", - " <th>1</th>\n", - " <td>3</td>\n", - " <td>3525</td>\n", - " <td>COMPRENDRE</td>\n", - " <td>Philosophie</td>\n", - " <td>terme de Philosophie,</td>\n", - " <td>Diderot</td>\n", - " <td>v3-1722-0</td>\n", - " <td>NaN</td>\n", - " <td>NaN</td>\n", - " <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n", - " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n", - " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n", - " <td>92</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2</th>\n", - " <td>1</td>\n", - " <td>2560</td>\n", - " <td>ANCRE</td>\n", - " <td>Marine</td>\n", - " <td>Marine</td>\n", - " <td>d'Alembert & Diderot</td>\n", - " <td>v1-1865-0</td>\n", - " <td>marine</td>\n", - " <td>Marine</td>\n", - " <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n", - " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n", - " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n", - " <td>3327</td>\n", - " </tr>\n", - " <tr>\n", - " <th>3</th>\n", - " <td>16</td>\n", - " <td>4241</td>\n", - " <td>VAKEBARO</td>\n", - " <td>Géographie moderne</td>\n", - " <td>Géog. mod.</td>\n", - " <td>unsigned</td>\n", - " <td>v16-2587-0</td>\n", - " <td>géographie</td>\n", - " <td>Géographie</td>\n", - " <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n", - " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n", - " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n", - " <td>34</td>\n", - " </tr>\n", - " <tr>\n", - " <th>4</th>\n", - " <td>8</td>\n", - " <td>3281</td>\n", - " <td>INSPECTEUR</td>\n", - " <td>Histoire ancienne</td>\n", - " <td>Hist. anc.</td>\n", - " <td>unsigned</td>\n", - " <td>v8-2533-0</td>\n", - " <td>histoire</td>\n", - " <td>Histoire</td>\n", - " <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n", - " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n", - " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n", - " <td>102</td>\n", - " </tr>\n", - " </tbody>\n", - "</table>\n", - "</div>" - ], - "text/plain": [ - " volume numero head normClass classEDdA \\\n", - "0 11 2973 ORNIS Commerce Comm. \n", - "1 3 3525 COMPRENDRE Philosophie terme de Philosophie, \n", - "2 1 2560 ANCRE Marine Marine \n", - "3 16 4241 VAKEBARO Géographie moderne Géog. mod. \n", - "4 8 3281 INSPECTEUR Histoire ancienne Hist. anc. \n", - "\n", - " author id_enccre domaine_enccre ensemble_domaine_enccre \\\n", - "0 unsigned v11-1767-0 commerce Commerce \n", - "1 Diderot v3-1722-0 NaN NaN \n", - "2 d'Alembert & Diderot v1-1865-0 marine Marine \n", - "3 unsigned v16-2587-0 géographie Géographie \n", - "4 unsigned v8-2533-0 histoire Histoire \n", - "\n", - " content \\\n", - "0 ORNIS, s. m. toile des Indes, (Comm.) sortes d... \n", - "1 * COMPRENDRE, v. act. terme de Philosophie,\\nc... \n", - "2 ANCRE, s. f. (Marine.) est un instrument de fe... \n", - "3 VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es... \n", - "4 INSPECTEUR, s. m. inspector ; (Hist. anc.) cel... \n", - "\n", - " contentWithoutClass \\\n", - "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... \n", - "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... \n", - "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... \n", - "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... \n", - "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... \n", - "\n", - " firstParagraph nb_word \n", - "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... 45 \n", - "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... 92 \n", - "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... 3327 \n", - "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... 34 \n", - "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... 102 " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n", "df.head()" @@ -337,27 +174,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(15854, 13)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "df.shape" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -368,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -377,77 +203,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(13441, 13)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "df.shape" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['Commerce',\n", - " 'Marine',\n", - " 'Géographie',\n", - " 'Histoire',\n", - " 'Belles-lettres - Poésie',\n", - " 'Economie domestique',\n", - " 'Droit - Jurisprudence',\n", - " 'Médecine - Chirurgie',\n", - " 'Militaire (Art) - Guerre - Arme',\n", - " 'Beaux-arts',\n", - " 'Antiquité',\n", - " 'Histoire naturelle',\n", - " 'Grammaire',\n", - " 'Philosophie',\n", - " 'Arts et métiers',\n", - " 'Pharmacie',\n", - " 'Religion',\n", - " 'Pêche',\n", - " 'Anatomie',\n", - " 'Architecture',\n", - " 'Musique',\n", - " 'Jeu',\n", - " 'Caractères',\n", - " 'Métiers',\n", - " 'Physique - [Sciences physico-mathématiques]',\n", - " 'Maréchage - Manège',\n", - " 'Chimie',\n", - " 'Blason',\n", - " 'Chasse',\n", - " 'Mathématiques',\n", - " 'Médailles',\n", - " 'Superstition',\n", - " 'Agriculture - Economie rustique',\n", - " 'Mesure',\n", - " 'Monnaie',\n", - " 'Minéralogie',\n", - " 'Politique',\n", - " 'Spectacle']" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "classes = df[column_class].unique().tolist()\n", "classes" @@ -471,80 +238,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cf58eccecc1847e48d520a83040e3ec7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/899k [00:00<?, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2e91d1ab181f409f9c6263a255991c9f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/456k [00:00<?, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4c23b1cb783d4c25b8332447bf25755a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/26.0 [00:00<?, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "203385bbd7664b8b87250e4883cc300f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/1.15k [00:00<?, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c068998923a14abf8e38d8e0d89248ad", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/1.63G [00:00<?, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# load model pretrained on MNLI\n", "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n", @@ -583,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -638,84 +334,18 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\"" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "df[column_text].tolist()[0]" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n" - ] - }, - { - "data": { - "text/plain": [ - "[('Commerce', 70.05096077919006),\n", - " ('Anatomie', 68.73840689659119),\n", - " ('Politique', 60.71174740791321),\n", - " ('Géographie', 59.156250953674316),\n", - " ('Architecture', 58.74174237251282),\n", - " ('Histoire', 57.459235191345215),\n", - " ('Agriculture - Economie rustique', 53.53081226348877),\n", - " ('Histoire naturelle', 48.459288477897644),\n", - " ('Antiquité', 46.68458700180054),\n", - " ('Beaux-arts', 42.856183648109436),\n", - " ('Mesure', 41.31035804748535),\n", - " ('Jeu', 41.22118949890137),\n", - " ('Droit - Jurisprudence', 41.1332905292511),\n", - " ('Minéralogie', 38.137245178222656),\n", - " ('Spectacle', 37.80339956283569),\n", - " ('Pêche', 37.214648723602295),\n", - " ('Superstition', 36.727988719940186),\n", - " ('Arts et métiers', 36.511969566345215),\n", - " ('Métiers', 36.5054726600647),\n", - " ('Monnaie', 35.89862287044525),\n", - " ('Musique', 32.74966776371002),\n", - " ('Mathématiques', 32.70111680030823),\n", - " ('Chasse', 29.35197949409485),\n", - " ('Economie domestique', 28.346234560012817),\n", - " ('Philosophie', 27.653270959854126),\n", - " ('Chimie', 25.783824920654297),\n", - " ('Physique - [Sciences physico-mathématiques]', 25.4037082195282),\n", - " ('Médailles', 24.58679974079132),\n", - " ('Grammaire', 22.36253321170807),\n", - " ('Caractères', 20.14845609664917),\n", - " ('Pharmacie', 19.720394909381866),\n", - " ('Militaire (Art) - Guerre - Arme', 19.682711362838745),\n", - " ('Médecine - Chirurgie', 18.615825474262238),\n", - " ('Marine', 18.208028376102448),\n", - " ('Belles-lettres - Poésie', 13.306896388530731),\n", - " ('Blason', 10.476677119731903),\n", - " ('Religion', 9.702161699533463),\n", - " ('Maréchage - Manège', 4.211411997675896)]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "premise = df[column_text].tolist()[0]\n", "\n", @@ -732,20 +362,9 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'Commerce'" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "probs[0][0]" ] @@ -778,48 +397,18 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "672.05" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "len(texts) / 20" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 1%| | 5/673 [10:28<23:19:26, 125.70s/it]\n" - ] - }, - { - "ename": "IndexError", - "evalue": "string index out of range", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[36], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m prob_labels \u001b[39m=\u001b[39m []\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m j, content \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(batch):\n\u001b[0;32m---> 12\u001b[0m true_probs \u001b[39m=\u001b[39m zero_shot_prediction(content[i][:\u001b[39m1024\u001b[39m], classes)\n\u001b[1;32m 14\u001b[0m \u001b[39m#pred_labels.append(get_highest_score(true_probs, classes)[0])\u001b[39;00m\n\u001b[1;32m 15\u001b[0m prob_labels\u001b[39m.\u001b[39mappend(get_sorted_scores(true_probs, classes))\n", - "\u001b[0;31mIndexError\u001b[0m: string index out of range" - ] - } - ], + "outputs": [], "source": [ "texts = df[column_text].tolist()\n", "batch_size = 20\n", @@ -830,9 +419,9 @@ "\n", " prob_labels = []\n", "\n", - " for j, content in enumerate(batch):\n", + " for content in batch:\n", "\n", - " true_probs = zero_shot_prediction(content[j][:1024], classes)\n", + " true_probs = zero_shot_prediction(content[:512], classes)\n", " \n", " #pred_labels.append(get_highest_score(true_probs, classes)[0])\n", " prob_labels.append(get_sorted_scores(true_probs, classes))\n", @@ -844,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [] -- GitLab