{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "aXLlx8vXQlJw"
      },
      "source": [
        "# Zero Shot Topic Classification with Transformers\n",
        "\n",
        "https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
        "\n",
        "https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd\n",
        "\n",
        "https://huggingface.co/spaces/joeddav/zero-shot-demo"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "3kYI_pq3Q1BT"
      },
      "source": [
        "## 1. Configuration"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "P_L0rDhZQ6Fn"
      },
      "source": [
        "### 1.1 Setup colab environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FsAR4CsB3aUc",
        "outputId": "e0791012-6858-4ee0-f724-7f33c6985ee8"
      },
      "outputs": [],
      "source": [
        "from psutil import virtual_memory\n",
        "ram_gb = virtual_memory().total / 1e9\n",
        "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n",
        "\n",
        "if ram_gb < 20:\n",
        "  print('Not using a high-RAM runtime')\n",
        "else:\n",
        "  print('You are using a high-RAM runtime!')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h5MwRwL53aYY",
        "outputId": "20a93907-e5df-47b1-9172-d1693ef76dc5"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "output_path = \"drive/MyDrive/Classification-EDdA/\""
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "4z78CLYi75kV"
      },
      "source": [
        "### 1.2 Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "bcptSr6o3ac7"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "from transformers import BartForSequenceClassification, BartTokenizer\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Lc1DRh4b7mto"
      },
      "source": [
        "## 2. Load datasets"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### 2.1 Download datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ybiJYL0h3ahh",
        "outputId": "0638f9a2-f9a0-4d96-9760-991ddc5747ca"
      },
      "outputs": [],
      "source": [
        "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv\n",
        "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv\n",
        "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "dataset_path = 'EDdA_dataframe_withContent.tsv'\n",
        "training_set_path = 'training_set.tsv'\n",
        "test_set_path = 'test_set.tsv'\n",
        "\n",
        "input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n",
        "#input_path = ''\n",
        "output_path = ''"
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LRKJzWmf3pCg",
        "outputId": "686c3ef4-8267-4266-95af-7193725aadca"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>volume</th>\n",
              "      <th>numero</th>\n",
              "      <th>head</th>\n",
              "      <th>normClass</th>\n",
              "      <th>classEDdA</th>\n",
              "      <th>author</th>\n",
              "      <th>id_enccre</th>\n",
              "      <th>domaine_enccre</th>\n",
              "      <th>ensemble_domaine_enccre</th>\n",
              "      <th>content</th>\n",
              "      <th>contentWithoutClass</th>\n",
              "      <th>firstParagraph</th>\n",
              "      <th>nb_word</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>11</td>\n",
              "      <td>2973</td>\n",
              "      <td>ORNIS</td>\n",
              "      <td>Commerce</td>\n",
              "      <td>Comm.</td>\n",
              "      <td>unsigned</td>\n",
              "      <td>v11-1767-0</td>\n",
              "      <td>commerce</td>\n",
              "      <td>Commerce</td>\n",
              "      <td>ORNIS, s. m. toile des Indes, (Comm.) sortes d...</td>\n",
              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
              "      <td>ORNIS, s. m. toile des Indes, () sortes de\\nto...</td>\n",
              "      <td>45</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>3</td>\n",
              "      <td>3525</td>\n",
              "      <td>COMPRENDRE</td>\n",
              "      <td>Philosophie</td>\n",
              "      <td>terme de Philosophie,</td>\n",
              "      <td>Diderot</td>\n",
              "      <td>v3-1722-0</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>* COMPRENDRE, v. act. terme de Philosophie,\\nc...</td>\n",
              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
              "      <td>* COMPRENDRE, v. act. \\nc'est appercevoir la l...</td>\n",
              "      <td>92</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>1</td>\n",
              "      <td>2560</td>\n",
              "      <td>ANCRE</td>\n",
              "      <td>Marine</td>\n",
              "      <td>Marine</td>\n",
              "      <td>d'Alembert &amp; Diderot</td>\n",
              "      <td>v1-1865-0</td>\n",
              "      <td>marine</td>\n",
              "      <td>Marine</td>\n",
              "      <td>ANCRE, s. f. (Marine.) est un instrument de fe...</td>\n",
              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
              "      <td>ANCRE, s. f. (.) est un instrument de fer\\nABC...</td>\n",
              "      <td>3327</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>16</td>\n",
              "      <td>4241</td>\n",
              "      <td>VAKEBARO</td>\n",
              "      <td>Géographie moderne</td>\n",
              "      <td>Géog. mod.</td>\n",
              "      <td>unsigned</td>\n",
              "      <td>v16-2587-0</td>\n",
              "      <td>géographie</td>\n",
              "      <td>Géographie</td>\n",
              "      <td>VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...</td>\n",
              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
              "      <td>VAKEBARO, () vallée du royaume\\nd'Espagne dans...</td>\n",
              "      <td>34</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>8</td>\n",
              "      <td>3281</td>\n",
              "      <td>INSPECTEUR</td>\n",
              "      <td>Histoire ancienne</td>\n",
              "      <td>Hist. anc.</td>\n",
              "      <td>unsigned</td>\n",
              "      <td>v8-2533-0</td>\n",
              "      <td>histoire</td>\n",
              "      <td>Histoire</td>\n",
              "      <td>INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...</td>\n",
              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
              "      <td>INSPECTEUR, s. m. inspector ; () celui \\nà qui...</td>\n",
              "      <td>102</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   volume  numero        head           normClass              classEDdA  \\\n",
              "0      11    2973       ORNIS            Commerce                  Comm.   \n",
              "1       3    3525  COMPRENDRE         Philosophie  terme de Philosophie,   \n",
              "2       1    2560       ANCRE              Marine                 Marine   \n",
              "3      16    4241    VAKEBARO  Géographie moderne             Géog. mod.   \n",
              "4       8    3281  INSPECTEUR   Histoire ancienne             Hist. anc.   \n",
              "\n",
              "                 author   id_enccre domaine_enccre ensemble_domaine_enccre  \\\n",
              "0              unsigned  v11-1767-0       commerce                Commerce   \n",
              "1               Diderot   v3-1722-0            NaN                     NaN   \n",
              "2  d'Alembert & Diderot   v1-1865-0         marine                  Marine   \n",
              "3              unsigned  v16-2587-0     géographie              Géographie   \n",
              "4              unsigned   v8-2533-0       histoire                Histoire   \n",
              "\n",
              "                                             content  \\\n",
              "0  ORNIS, s. m. toile des Indes, (Comm.) sortes d...   \n",
              "1  * COMPRENDRE, v. act. terme de Philosophie,\\nc...   \n",
              "2  ANCRE, s. f. (Marine.) est un instrument de fe...   \n",
              "3  VAKEBARO, (Géog. mod.) vallée du royaume\\nd'Es...   \n",
              "4  INSPECTEUR, s. m. inspector ; (Hist. anc.) cel...   \n",
              "\n",
              "                                 contentWithoutClass  \\\n",
              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...   \n",
              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...   \n",
              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...   \n",
              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...   \n",
              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...   \n",
              "\n",
              "                                      firstParagraph  nb_word  \n",
              "0  ORNIS, s. m. toile des Indes, () sortes de\\nto...       45  \n",
              "1  * COMPRENDRE, v. act. \\nc'est appercevoir la l...       92  \n",
              "2  ANCRE, s. f. (.) est un instrument de fer\\nABC...     3327  \n",
              "3  VAKEBARO, () vallée du royaume\\nd'Espagne dans...       34  \n",
              "4  INSPECTEUR, s. m. inspector ; () celui \\nà qui...      102  "
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(15854, 13)"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "df.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {},
      "outputs": [],
      "source": [
        "#column_text = 'contentWithoutClass'\n",
        "column_text = 'content'\n",
        "column_class = 'ensemble_domaine_enccre'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {},
      "outputs": [],
      "source": [
        "df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(13441, 13)"
            ]
          },
          "execution_count": 23,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "df.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "['Commerce',\n",
              " 'Marine',\n",
              " 'Géographie',\n",
              " 'Histoire',\n",
              " 'Belles-lettres - Poésie',\n",
              " 'Economie domestique',\n",
              " 'Droit - Jurisprudence',\n",
              " 'Médecine - Chirurgie',\n",
              " 'Militaire (Art) - Guerre - Arme',\n",
              " 'Beaux-arts',\n",
              " 'Antiquité',\n",
              " 'Histoire naturelle',\n",
              " 'Grammaire',\n",
              " 'Philosophie',\n",
              " 'Arts et métiers',\n",
              " 'Pharmacie',\n",
              " 'Religion',\n",
              " 'Pêche',\n",
              " 'Anatomie',\n",
              " 'Architecture',\n",
              " 'Musique',\n",
              " 'Jeu',\n",
              " 'Caractères',\n",
              " 'Métiers',\n",
              " 'Physique - [Sciences physico-mathématiques]',\n",
              " 'Maréchage - Manège',\n",
              " 'Chimie',\n",
              " 'Blason',\n",
              " 'Chasse',\n",
              " 'Mathématiques',\n",
              " 'Médailles',\n",
              " 'Superstition',\n",
              " 'Agriculture - Economie rustique',\n",
              " 'Mesure',\n",
              " 'Monnaie',\n",
              " 'Minéralogie',\n",
              " 'Politique',\n",
              " 'Spectacle']"
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "classes = df[column_class].unique().tolist()\n",
        "classes"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Classification"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the \"premise\" and to turn each candidate label into a \"hypothesis.\" If the NLI model predicts that the premise \"entails\" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers."
      ]
    },
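    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small, hedged sketch of the hypothesis step described above: the blog example wraps each label in a sentence such as \"This text is about politics.\" The cell below shows one possible (hypothetical) French template for the EDdA domain labels; note that the experiments further down simply pass the raw class names as hypotheses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical hypothesis template (assumption: not used in the experiments below,\n",
        "# which pass the raw class names directly as hypotheses).\n",
        "hypothesis_template = \"Ce texte parle de {}.\"\n",
        "hypotheses = [hypothesis_template.format(c.lower()) for c in classes]\n",
        "hypotheses[:3]"
      ]
    },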
    {
      "cell_type": "code",
      "execution_count": 30,
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e5a45c55993f47019fbdc0aceda84def",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a7285f3fe7154920a0bb05fdb921d6f9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0cb38fe0e7934c49bbce246e88dd6e53",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2c9dd205446a4b25848f1f06beeac8ae",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "032334b5310d4588993b7d85329d916c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# load model pretrained on MNLI\n",
        "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
        "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')"
      ]
    },
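    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an aside (not used in the experiments below), 🤗 Transformers also exposes the same premise/hypothesis idea through its `zero-shot-classification` pipeline. The sketch assumes the dataframe `df` and the `classes` list defined above and simply reports the top-ranked label for the first test article."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: same zero-shot idea via the built-in pipeline,\n",
        "# as an alternative to the manual encode/logits code below.\n",
        "from transformers import pipeline\n",
        "\n",
        "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-large-mnli\")\n",
        "result = classifier(df[column_text].tolist()[0], candidate_labels=classes)\n",
        "print(result[\"labels\"][0], round(result[\"scores\"][0], 3))"
      ]
    },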
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "''' \n",
        "## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
        "\n",
        "# load model pretrained on MNLI\n",
        "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
        "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
        "\n",
        "# pose sequence as a NLI premise and label (politics) as a hypothesis\n",
        "premise = 'Who are you voting for in 2020?'\n",
        "hypothesis = 'This text is about politics.'\n",
        "\n",
        "# run through model pre-trained on MNLI\n",
        "input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
        "logits = model(input_ids)[0]\n",
        "\n",
        "# we throw away \"neutral\" (dim 1) and take the probability of\n",
        "# \"entailment\" (2) as the probability of the label being true \n",
        "entail_contradiction_logits = logits[:,[0,2]]\n",
        "probs = entail_contradiction_logits.softmax(dim=1)\n",
        "true_prob = probs[:,1].item() * 100\n",
        "print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
        "'''"
      "execution_count": 37,
        "def zero_shot_prediction(premise, hypotheses):\n",
        "    # list to store the true probability of each hypothesis\n",
        "    true_probs = []\n",
        "    # loop through hypotheses\n",
        "    for hypothesis in hypotheses:\n",
        "        # run through model pre-trained on MNLI\n",
        "        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
        "        logits = model(input_ids)[0]\n",
        "        # we throw away \"neutral\" (dim 1) and take the probability of\n",
        "        # \"entailment\" (2) as the probability of the label being true \n",
        "        entail_contradiction_logits = logits[:,[0,2]]\n",
        "        probs = entail_contradiction_logits.softmax(dim=1)\n",
        "        true_prob = probs[:,1].item() * 100\n",
        "        # append true probability to list\n",
        "        true_probs.append(true_prob)\n",
        "    return true_probs\n",
        "def get_highest_score(true_probs, hypotheses):\n",
        "    # print the true probability for each hypothesis\n",
        "    #for i, hypothesis in enumerate(hypotheses):\n",
        "    #    print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n",
        "    #    print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
        "    # get index of hypothesis with highest score\n",
        "    highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n",
        "    # get hypothesis with highest score\n",
        "    highest_hypothesis = hypotheses[highest_index]\n",
        "    # get highest probability\n",
        "    highest_prob = true_probs[highest_index]\n",
        "    \n",
        "    return (highest_hypothesis, highest_prob)\n",
        "    "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "df[column_text].tolist()[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n"
          ]
        }
      ],
      "source": [
        "premise = df[column_text].tolist()[0]\n",
        "\n",
        "true_probs = zero_shot_prediction(premise, classes)\n",
        "highest_score = get_highest_score(true_probs, classes)\n",
        "\n",
        "# print the results\n",
        "print(f'The hypothesis with the highest score is: \"{highest_score[0]}\" with a probability of {highest_score[1]:0.2f}%')"
      ]
    },
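    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A possible next step (sketch only, not run here): repeat the prediction over a small sample of the test set and compare against the gold `ensemble_domaine_enccre` labels. The sample size and the rough character-level truncation below are assumptions to keep runtime and input length manageable (BART accepts at most 1024 tokens)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: zero-shot accuracy on a small sample of the test set.\n",
        "# n_samples and the 2000-character cut are hypothetical choices, not values from the original notebook.\n",
        "n_samples = 50\n",
        "correct = 0\n",
        "for premise, gold in zip(df[column_text].tolist()[:n_samples], df[column_class].tolist()[:n_samples]):\n",
        "    premise = premise[:2000]  # rough cut so long articles stay under the model's 1024-token limit\n",
        "    true_probs = zero_shot_prediction(premise, classes)\n",
        "    predicted, _ = get_highest_score(true_probs, classes)\n",
        "    correct += int(predicted == gold)\n",
        "print(f'Accuracy on {n_samples} articles: {correct / n_samples:.2%}')"
      ]
    },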
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "EDdA-Classification_Clustering.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "geode-classification-py39",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "16fac9c2d845f8e1f8c6fffffe3d3a0def61c7e42da17a08d00f279ad4dea797"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}