From 769ee67ad7b4c137b96ab639c86978af889657e3 Mon Sep 17 00:00:00 2001 From: Ludovic Moncla <moncla.ludovic@gmail.com> Date: Sat, 14 Jan 2023 17:38:25 +0100 Subject: [PATCH] Update Classification_Zero-Shot-Learning.ipynb --- .../Classification_Zero-Shot-Learning.ipynb | 473 +++++++++++++++++- 1 file changed, 457 insertions(+), 16 deletions(-) diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb index 78e4214..2869c84 100644 --- a/notebooks/Classification_Zero-Shot-Learning.ipynb +++ b/notebooks/Classification_Zero-Shot-Learning.ipynb @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "path = \"drive/MyDrive/Classification-EDdA/\"" + "output_path = \"drive/MyDrive/Classification-EDdA/\"" ] }, { @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "id": "bcptSr6o3ac7" }, @@ -143,19 +143,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "dataset_path = 'EDdA_dataframe_withContent.tsv'\n", "training_set_path = 'training_set.tsv'\n", "test_set_path = 'test_set.tsv'\n", - "\n" + "\n", + "input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n", + "#input_path = ''\n", + "output_path = ''" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -163,26 +166,296 @@ "id": "LRKJzWmf3pCg", "outputId": "686c3ef4-8267-4266-95af-7193725aadca" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>volume</th>\n", + " <th>numero</th>\n", + " <th>head</th>\n", + " <th>normClass</th>\n", + " <th>classEDdA</th>\n", + " <th>author</th>\n", + " <th>id_enccre</th>\n", + " <th>domaine_enccre</th>\n", + " <th>ensemble_domaine_enccre</th>\n", + " <th>content</th>\n", + " <th>contentWithoutClass</th>\n", + " <th>firstParagraph</th>\n", + " <th>nb_word</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>11</td>\n", + " <td>2973</td>\n", + " <td>ORNIS</td>\n", + " <td>Commerce</td>\n", + " <td>Comm.</td>\n", + " <td>unsigned</td>\n", + " <td>v11-1767-0</td>\n", + " <td>commerce</td>\n", + " <td>Commerce</td>\n", + " <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n", + " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n", + " <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n", + " <td>45</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>3</td>\n", + " <td>3525</td>\n", + " <td>COMPRENDRE</td>\n", + " <td>Philosophie</td>\n", + " <td>terme de Philosophie,</td>\n", + " <td>Diderot</td>\n", + " <td>v3-1722-0</td>\n", + " <td>NaN</td>\n", + " <td>NaN</td>\n", + " <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n", + " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n", + " <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n", + " <td>92</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>1</td>\n", + " <td>2560</td>\n", + " <td>ANCRE</td>\n", + " <td>Marine</td>\n", + " <td>Marine</td>\n", + " <td>d'Alembert & Diderot</td>\n", + " <td>v1-1865-0</td>\n", + " <td>marine</td>\n", + " <td>Marine</td>\n", + " <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n", + " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n", + " <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n", + " <td>3327</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>16</td>\n", + " <td>4241</td>\n", + " <td>VAKEBARO</td>\n", + " <td>Géographie moderne</td>\n", + " <td>Géog. mod.</td>\n", + " <td>unsigned</td>\n", + " <td>v16-2587-0</td>\n", + " <td>géographie</td>\n", + " <td>Géographie</td>\n", + " <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n", + " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n", + " <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n", + " <td>34</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>8</td>\n", + " <td>3281</td>\n", + " <td>INSPECTEUR</td>\n", + " <td>Histoire ancienne</td>\n", + " <td>Hist. anc.</td>\n", + " <td>unsigned</td>\n", + " <td>v8-2533-0</td>\n", + " <td>histoire</td>\n", + " <td>Histoire</td>\n", + " <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n", + " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n", + " <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n", + " <td>102</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " volume numero head normClass classEDdA \\\n", + "0 11 2973 ORNIS Commerce Comm. \n", + "1 3 3525 COMPRENDRE Philosophie terme de Philosophie, \n", + "2 1 2560 ANCRE Marine Marine \n", + "3 16 4241 VAKEBARO Géographie moderne Géog. mod. \n", + "4 8 3281 INSPECTEUR Histoire ancienne Hist. anc. \n", + "\n", + " author id_enccre domaine_enccre ensemble_domaine_enccre \\\n", + "0 unsigned v11-1767-0 commerce Commerce \n", + "1 Diderot v3-1722-0 NaN NaN \n", + "2 d'Alembert & Diderot v1-1865-0 marine Marine \n", + "3 unsigned v16-2587-0 géographie Géographie \n", + "4 unsigned v8-2533-0 histoire Histoire \n", + "\n", + " content \\\n", + "0 ORNIS, s. m. toile des Indes, (Comm.) sortes d... \n", + "1 * COMPRENDRE, v. act. terme de Philosophie,\\nc... \n", + "2 ANCRE, s. f. (Marine.) est un instrument de fe... \n", + "3 VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es... \n", + "4 INSPECTEUR, s. m. inspector ; (Hist. anc.) cel... \n", + "\n", + " contentWithoutClass \\\n", + "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... \n", + "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... \n", + "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... \n", + "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... \n", + "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... \n", + "\n", + " firstParagraph nb_word \n", + "0 ORNIS, s. m. toile des Indes, () sortes de\\nto... 45 \n", + "1 * COMPRENDRE, v. act. \\nc'est appercevoir la l... 92 \n", + "2 ANCRE, s. f. (.) est un instrument de fer\\nABC... 3327 \n", + "3 VAKEBARO, () vallée du royaume\\nd'Espagne dans... 34 \n", + "4 INSPECTEUR, s. m. inspector ; () celui \\nà qui... 102 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "df = pd.read_csv(test_set_path, sep=\"\\t\")\n", - "\n", + "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n", "df.head()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(15854, 13)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ - "column_text = 'contentWithoutClass'\n", + "#column_text = 'contentWithoutClass'\n", + "column_text = 'content'\n", "column_class = 'ensemble_domaine_enccre'" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(13441, 13)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Commerce',\n", + " 'Marine',\n", + " 'Géographie',\n", + " 'Histoire',\n", + " 'Belles-lettres - Poésie',\n", + " 'Economie domestique',\n", + " 'Droit - Jurisprudence',\n", + " 'Médecine - Chirurgie',\n", + " 'Militaire (Art) - Guerre - Arme',\n", + " 'Beaux-arts',\n", + " 'Antiquité',\n", + " 'Histoire naturelle',\n", + " 'Grammaire',\n", + " 'Philosophie',\n", + " 'Arts et métiers',\n", + " 'Pharmacie',\n", + " 'Religion',\n", + " 'Pêche',\n", + " 'Anatomie',\n", + " 'Architecture',\n", + " 'Musique',\n", + " 'Jeu',\n", + " 'Caractères',\n", + " 'Métiers',\n", + " 'Physique - [Sciences physico-mathématiques]',\n", + " 'Maréchage - Manège',\n", + " 'Chimie',\n", + " 'Blason',\n", + " 'Chasse',\n", + " 'Mathématiques',\n", + " 'Médailles',\n", + " 'Superstition',\n", + " 'Agriculture - Economie rustique',\n", + " 'Mesure',\n", + " 'Monnaie',\n", + " 'Minéralogie',\n", + " 'Politique',\n", + " 'Spectacle']" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classes = df[column_class].unique().tolist()\n", + "classes" + ] + }, + { + "cell_type": "code", + "execution_count": 35, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -191,7 +464,18 @@ "id": "RbsHOiJdNYRL", "outputId": "bbdafc35-cf09-4a20-c3c0-901b8adce561" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\"" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df[column_text].tolist()[0]" ] @@ -214,9 +498,80 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e5a45c55993f47019fbdc0aceda84def", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/899k [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a7285f3fe7154920a0bb05fdb921d6f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/456k [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0cb38fe0e7934c49bbce246e88dd6e53", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/26.0 [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2c9dd205446a4b25848f1f06beeac8ae", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/1.15k [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "032334b5310d4588993b7d85329d916c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/1.63G [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "# load model pretrained on MNLI\n", "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n", @@ -229,6 +584,13 @@ "metadata": {}, "outputs": [], "source": [ + "''' \n", + "## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n", + "\n", + "# load model pretrained on MNLI\n", + "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n", + "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n", + "\n", "# pose sequence as a NLI premise and label (politics) as a hypothesis\n", "premise = 'Who are you voting for in 2020?'\n", "hypothesis = 'This text is about politics.'\n", @@ -242,8 +604,79 @@ "entail_contradiction_logits = logits[:,[0,2]]\n", "probs = entail_contradiction_logits.softmax(dim=1)\n", "true_prob = probs[:,1].item() * 100\n", - "print(f'Probability that the label is true: {true_prob:0.2f}%')" + "print(f'Probability that the label is true: {true_prob:0.2f}%')\n", + "'''" ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n" + ] + } + ], + "source": [ + "# pose sequence as a NLI premise and label (politics) as a hypothesis\n", + "premise = df[column_text].tolist()[0]\n", + "#hypothesis = 'This text is about politics.'\n", + "hypotheses = classes\n", + "\n", + "# list to store the true probability of each hypothesis\n", + "true_probs = []\n", + "\n", + "# loop through hypotheses\n", + "for hypothesis in hypotheses:\n", + "\n", + " # run through model pre-trained on MNLI\n", + " input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n", + " logits = model(input_ids)[0]\n", + "\n", + " # we throw away \"neutral\" (dim 1) and take the probability of\n", + " # \"entailment\" (2) as the probability of the label being true \n", + " entail_contradiction_logits = logits[:,[0,2]]\n", + " probs = entail_contradiction_logits.softmax(dim=1)\n", + " true_prob = probs[:,1].item() * 100\n", + "\n", + " # append true probability to list\n", + " true_probs.append(true_prob)\n", + "\n", + "# print the true probability for each hypothesis\n", + "#for i, hypothesis in enumerate(hypotheses):\n", + "# print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n", + "# print(f'Probability that the label is true: {true_prob:0.2f}%')\n", + "\n", + "# get index of hypothesis with highest score\n", + "highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n", + "\n", + "# get hypothesis with highest score\n", + "highest_hypothesis = hypotheses[highest_index]\n", + "\n", + "# get highest probability\n", + "highest_prob = true_probs[highest_index]\n", + "\n", + "# print the results\n", + "print(f'The hypothesis with the highest score is: \"{highest_hypothesis}\" with a probability of {highest_prob:0.2f}%')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -259,8 +692,16 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" }, "vscode": { "interpreter": { -- GitLab