From 769ee67ad7b4c137b96ab639c86978af889657e3 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Sat, 14 Jan 2023 17:38:25 +0100
Subject: [PATCH] Update Classification_Zero-Shot-Learning.ipynb

---
 .../Classification_Zero-Shot-Learning.ipynb   | 473 +++++++++++++++++-
 1 file changed, 457 insertions(+), 16 deletions(-)

diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb
index 78e4214..2869c84 100644
--- a/notebooks/Classification_Zero-Shot-Learning.ipynb
+++ b/notebooks/Classification_Zero-Shot-Learning.ipynb
@@ -80,7 +80,7 @@
       "metadata": {},
       "outputs": [],
       "source": [
-        "path = \"drive/MyDrive/Classification-EDdA/\""
+        "output_path = \"drive/MyDrive/Classification-EDdA/\""
       ]
     },
     {
@@ -95,7 +95,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 1,
       "metadata": {
         "id": "bcptSr6o3ac7"
       },
@@ -143,19 +143,22 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 4,
       "metadata": {},
       "outputs": [],
       "source": [
         "dataset_path = 'EDdA_dataframe_withContent.tsv'\n",
         "training_set_path = 'training_set.tsv'\n",
         "test_set_path = 'test_set.tsv'\n",
-        "\n"
+        "\n",
+        "input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n",
+        "#input_path = ''\n",
+        "output_path = ''"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 18,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -163,26 +166,296 @@
         "id": "LRKJzWmf3pCg",
         "outputId": "686c3ef4-8267-4266-95af-7193725aadca"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<div>\n",
+              "<style scoped>\n",
+              "    .dataframe tbody tr th:only-of-type {\n",
+              "        vertical-align: middle;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe tbody tr th {\n",
+              "        vertical-align: top;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe thead th {\n",
+              "        text-align: right;\n",
+              "    }\n",
+              "</style>\n",
+              "<table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              "    <tr style=\"text-align: right;\">\n",
+              "      <th></th>\n",
+              "      <th>volume</th>\n",
+              "      <th>numero</th>\n",
+              "      <th>head</th>\n",
+              "      <th>normClass</th>\n",
+              "      <th>classEDdA</th>\n",
+              "      <th>author</th>\n",
+              "      <th>id_enccre</th>\n",
+              "      <th>domaine_enccre</th>\n",
+              "      <th>ensemble_domaine_enccre</th>\n",
+              "      <th>content</th>\n",
+              "      <th>contentWithoutClass</th>\n",
+              "      <th>firstParagraph</th>\n",
+              "      <th>nb_word</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <th>0</th>\n",
+              "      <td>11</td>\n",
+              "      <td>2973</td>\n",
+              "      <td>ORNIS</td>\n",
+              "      <td>Commerce</td>\n",
+              "      <td>Comm.</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v11-1767-0</td>\n",
+              "      <td>commerce</td>\n",
+              "      <td>Commerce</td>\n",
+              "      <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n",
+              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
+              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
+              "      <td>45</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>1</th>\n",
+              "      <td>3</td>\n",
+              "      <td>3525</td>\n",
+              "      <td>COMPRENDRE</td>\n",
+              "      <td>Philosophie</td>\n",
+              "      <td>terme de Philosophie,</td>\n",
+              "      <td>Diderot</td>\n",
+              "      <td>v3-1722-0</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n",
+              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
+              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
+              "      <td>92</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>2</th>\n",
+              "      <td>1</td>\n",
+              "      <td>2560</td>\n",
+              "      <td>ANCRE</td>\n",
+              "      <td>Marine</td>\n",
+              "      <td>Marine</td>\n",
+              "      <td>d'Alembert &amp; Diderot</td>\n",
+              "      <td>v1-1865-0</td>\n",
+              "      <td>marine</td>\n",
+              "      <td>Marine</td>\n",
+              "      <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n",
+              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
+              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
+              "      <td>3327</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>3</th>\n",
+              "      <td>16</td>\n",
+              "      <td>4241</td>\n",
+              "      <td>VAKEBARO</td>\n",
+              "      <td>Géographie moderne</td>\n",
+              "      <td>Géog. mod.</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v16-2587-0</td>\n",
+              "      <td>géographie</td>\n",
+              "      <td>Géographie</td>\n",
+              "      <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n",
+              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
+              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
+              "      <td>34</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>4</th>\n",
+              "      <td>8</td>\n",
+              "      <td>3281</td>\n",
+              "      <td>INSPECTEUR</td>\n",
+              "      <td>Histoire ancienne</td>\n",
+              "      <td>Hist. anc.</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v8-2533-0</td>\n",
+              "      <td>histoire</td>\n",
+              "      <td>Histoire</td>\n",
+              "      <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n",
+              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
+              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
+              "      <td>102</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table>\n",
+              "</div>"
+            ],
+            "text/plain": [
+              "   volume  numero        head           normClass              classEDdA  \\\n",
+              "0      11    2973       ORNIS            Commerce                  Comm.   \n",
+              "1       3    3525  COMPRENDRE         Philosophie  terme de Philosophie,   \n",
+              "2       1    2560       ANCRE              Marine                 Marine   \n",
+              "3      16    4241    VAKEBARO  Géographie moderne             Géog. mod.   \n",
+              "4       8    3281  INSPECTEUR   Histoire ancienne             Hist. anc.   \n",
+              "\n",
+              "                 author   id_enccre domaine_enccre ensemble_domaine_enccre  \\\n",
+              "0              unsigned  v11-1767-0       commerce                Commerce   \n",
+              "1               Diderot   v3-1722-0            NaN                     NaN   \n",
+              "2  d'Alembert & Diderot   v1-1865-0         marine                  Marine   \n",
+              "3              unsigned  v16-2587-0     géographie              Géographie   \n",
+              "4              unsigned   v8-2533-0       histoire                Histoire   \n",
+              "\n",
+              "                                             content  \\\n",
+              "0  ORNIS, s. m. toile des Indes, (Comm.) sortes d...   \n",
+              "1  * COMPRENDRE, v. act. terme de Philosophie,\\nc...   \n",
+              "2  ANCRE, s. f. (Marine.) est un instrument de fe...   \n",
+              "3  VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...   \n",
+              "4  INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...   \n",
+              "\n",
+              "                                 contentWithoutClass  \\\n",
+              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...   \n",
+              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...   \n",
+              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...   \n",
+              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...   \n",
+              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...   \n",
+              "\n",
+              "                                      firstParagraph  nb_word  \n",
+              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...       45  \n",
+              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...       92  \n",
+              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...     3327  \n",
+              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...       34  \n",
+              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...      102  "
+            ]
+          },
+          "execution_count": 18,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
       "source": [
-        "df = pd.read_csv(test_set_path, sep=\"\\t\")\n",
-        "\n",
+        "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
         "df.head()"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 19,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "(15854, 13)"
+            ]
+          },
+          "execution_count": 19,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "df.shape"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 20,
       "metadata": {},
       "outputs": [],
       "source": [
-        "column_text = 'contentWithoutClass'\n",
+        "#column_text = 'contentWithoutClass'\n",
+        "column_text = 'content'\n",
         "column_class = 'ensemble_domaine_enccre'"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 22,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 23,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "(13441, 13)"
+            ]
+          },
+          "execution_count": 23,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "df.shape"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 32,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "['Commerce',\n",
+              " 'Marine',\n",
+              " 'Géographie',\n",
+              " 'Histoire',\n",
+              " 'Belles-lettres - Poésie',\n",
+              " 'Economie domestique',\n",
+              " 'Droit - Jurisprudence',\n",
+              " 'Médecine - Chirurgie',\n",
+              " 'Militaire (Art) - Guerre - Arme',\n",
+              " 'Beaux-arts',\n",
+              " 'Antiquité',\n",
+              " 'Histoire naturelle',\n",
+              " 'Grammaire',\n",
+              " 'Philosophie',\n",
+              " 'Arts et métiers',\n",
+              " 'Pharmacie',\n",
+              " 'Religion',\n",
+              " 'Pêche',\n",
+              " 'Anatomie',\n",
+              " 'Architecture',\n",
+              " 'Musique',\n",
+              " 'Jeu',\n",
+              " 'Caractères',\n",
+              " 'Métiers',\n",
+              " 'Physique - [Sciences physico-mathématiques]',\n",
+              " 'Maréchage - Manège',\n",
+              " 'Chimie',\n",
+              " 'Blason',\n",
+              " 'Chasse',\n",
+              " 'Mathématiques',\n",
+              " 'Médailles',\n",
+              " 'Superstition',\n",
+              " 'Agriculture - Economie rustique',\n",
+              " 'Mesure',\n",
+              " 'Monnaie',\n",
+              " 'Minéralogie',\n",
+              " 'Politique',\n",
+              " 'Spectacle']"
+            ]
+          },
+          "execution_count": 32,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "classes = df[column_class].unique().tolist()\n",
+        "classes"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 35,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -191,7 +464,18 @@
         "id": "RbsHOiJdNYRL",
         "outputId": "bbdafc35-cf09-4a20-c3c0-901b8adce561"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\""
+            ]
+          },
+          "execution_count": 35,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
       "source": [
         "df[column_text].tolist()[0]"
       ]
@@ -214,9 +498,80 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 30,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "e5a45c55993f47019fbdc0aceda84def",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "a7285f3fe7154920a0bb05fdb921d6f9",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "0cb38fe0e7934c49bbce246e88dd6e53",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "2c9dd205446a4b25848f1f06beeac8ae",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "032334b5310d4588993b7d85329d916c",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        }
+      ],
       "source": [
         "# load model pretrained on MNLI\n",
         "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
@@ -229,6 +584,13 @@
       "metadata": {},
       "outputs": [],
       "source": [
+        "''' \n",
+        "## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
+        "\n",
+        "# load model pretrained on MNLI\n",
+        "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
+        "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
+        "\n",
         "# pose sequence as a NLI premise and label (politics) as a hypothesis\n",
         "premise = 'Who are you voting for in 2020?'\n",
         "hypothesis = 'This text is about politics.'\n",
@@ -242,8 +604,79 @@
         "entail_contradiction_logits = logits[:,[0,2]]\n",
         "probs = entail_contradiction_logits.softmax(dim=1)\n",
         "true_prob = probs[:,1].item() * 100\n",
-        "print(f'Probability that the label is true: {true_prob:0.2f}%')"
+        "print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
+        "'''"
       ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 36,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n"
+          ]
+        }
+      ],
+      "source": [
+        "# pose the article text as an NLI premise and each candidate class as a hypothesis\n",
+        "premise = df[column_text].tolist()[0]\n",
+        "#hypothesis = 'This text is about politics.'\n",
+        "hypotheses = classes\n",
+        "\n",
+        "# list to store the true probability of each hypothesis\n",
+        "true_probs = []\n",
+        "\n",
+        "# loop through hypotheses\n",
+        "for hypothesis in hypotheses:\n",
+        "\n",
+        "    # run through model pre-trained on MNLI\n",
+        "    input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
+        "    logits = model(input_ids)[0]\n",
+        "\n",
+        "    # drop the \"neutral\" logit (index 1) and softmax over contradiction (0) and\n",
+        "    # entailment (2); the entailment probability is the probability the label is true\n",
+        "    entail_contradiction_logits = logits[:,[0,2]]\n",
+        "    probs = entail_contradiction_logits.softmax(dim=1)\n",
+        "    true_prob = probs[:,1].item() * 100\n",
+        "\n",
+        "    # append true probability to list\n",
+        "    true_probs.append(true_prob)\n",
+        "\n",
+        "# print the true probability for each hypothesis\n",
+        "#for i, hypothesis in enumerate(hypotheses):\n",
+        "#    print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n",
+        "#    print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
+        "\n",
+        "# get index of hypothesis with highest score\n",
+        "highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n",
+        "\n",
+        "# get hypothesis with highest score\n",
+        "highest_hypothesis = hypotheses[highest_index]\n",
+        "\n",
+        "# get highest probability\n",
+        "highest_prob = true_probs[highest_index]\n",
+        "\n",
+        "# print the results\n",
+        "print(f'The hypothesis with the highest score is: \"{highest_hypothesis}\" with a probability of {highest_prob:0.2f}%')"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": []
     }
   ],
   "metadata": {
@@ -259,8 +692,16 @@
       "name": "python3"
     },
     "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
       "name": "python",
-      "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]"
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.9.13"
     },
     "vscode": {
       "interpreter": {
-- 
GitLab