Skip to content
Snippets Groups Projects
Classification_Zero-Shot-Learning.ipynb 12.6 KiB
Newer Older
{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "aXLlx8vXQlJw"
      },
      "source": [
        "# Zero Shot Topic Classification with Transformers\n",
        "\n",
        "https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
        "\n",
        "https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd\n",
        "\n",
        "https://huggingface.co/spaces/joeddav/zero-shot-demo"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "3kYI_pq3Q1BT"
      },
      "source": [
        "## 1. Configuration"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "P_L0rDhZQ6Fn"
      },
      "source": [
        "### 1.1 Setup colab environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FsAR4CsB3aUc",
        "outputId": "e0791012-6858-4ee0-f724-7f33c6985ee8"
      },
      "outputs": [],
      "source": [
        "from psutil import virtual_memory\n",
        "\n",
        "# Total physical memory of this runtime, in decimal gigabytes (1 GB = 1e9 bytes)\n",
        "ram_gb = virtual_memory().total / 1e9\n",
        "print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\\n')\n",
        "\n",
        "# 20 GB is the threshold this notebook uses to call a runtime high-RAM\n",
        "print('Not using a high-RAM runtime' if ram_gb < 20 else 'You are using a high-RAM runtime!')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h5MwRwL53aYY",
        "outputId": "20a93907-e5df-47b1-9172-d1693ef76dc5"
      },
      "outputs": [],
      "source": [
        "# Mount Google Drive at /content/drive (Colab only)\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Results folder on the mounted Drive. This is a relative path, presumably resolved from\n",
        "# Colab's working directory so it lands inside the mount above (confirm).\n",
        "# NOTE(review): output_path is reassigned to '' in the path-configuration cell further down.\n",
        "output_path = \"drive/MyDrive/Classification-EDdA/\""
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "4z78CLYi75kV"
      },
      "source": [
        "### 1.2 Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bcptSr6o3ac7"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "from tqdm import tqdm\n",
        "from transformers import BartForSequenceClassification, BartTokenizer\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Lc1DRh4b7mto"
      },
      "source": [
        "## 2. Load datasets"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 2.1 Download datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ybiJYL0h3ahh",
        "outputId": "0638f9a2-f9a0-4d96-9760-991ddc5747ca"
      },
      "outputs": [],
      "source": [
        "# Download the three TSV files into the current working directory.\n",
        "# -nc (no-clobber) skips files that already exist, so re-running this cell\n",
        "# does not pile up duplicate copies such as test_set.tsv.1\n",
        "!wget -nc https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv\n",
        "!wget -nc https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv\n",
        "!wget -nc https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "dataset_path = 'EDdA_dataframe_withContent.tsv'\n",
        "training_set_path = 'training_set.tsv'\n",
        "test_set_path = 'test_set.tsv'\n",
        "\n",
        "input_path = '/Users/lmoncla/Nextcloud-LIRIS/GEODE/GEODE - Partage consortium/Classification domaines EDdA/datasets/'\n",
        "#input_path = ''\n",
        "output_path = ''"
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LRKJzWmf3pCg",
        "outputId": "686c3ef4-8267-4266-95af-7193725aadca"
      },
        "# Load the test set (tab-separated) from input_path\n",
        "df = pd.read_csv(input_path + test_set_path, sep=\"\\t\")\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "# (rows, columns) as loaded\n",
        "df.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Text column fed to the model as the NLI premise\n",
        "#column_text = 'contentWithoutClass'\n",
        "column_text = 'content'\n",
        "# Column holding the reference domain label of each article\n",
        "column_class = 'ensemble_domaine_enccre'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Drop rows missing the text or the label, then renumber rows from 0\n",
        "df = df.dropna(subset=[column_text, column_class]).reset_index(drop=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "# (rows, columns) after dropping incomplete rows\n",
        "df.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "# Candidate labels for zero-shot scoring: every distinct label present in the data\n",
        "classes = df[column_class].unique().tolist()\n",
        "classes"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Classification"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The approach, proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161), uses a pre-trained MNLI sequence-pair classifier as an out-of-the-box zero-shot text classifier that actually works pretty well. The idea is to take the sequence we're interested in labeling as the \"premise\" and to turn each candidate label into a \"hypothesis.\" If the NLI model predicts that the premise \"entails\" the hypothesis, we take the label to be true. See the code snippet below which demonstrates how easily this can be done with 🤗 Transformers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "# BART-large checkpoint fine-tuned on MNLI; its 3-way head\n",
        "# (contradiction / neutral / entailment) is reused for zero-shot scoring\n",
        "tokenizer, model = (\n",
        "    BartTokenizer.from_pretrained('facebook/bart-large-mnli'),\n",
        "    BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli'),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Reference example kept for documentation only: the code below sits inside a\n",
        "# string literal and is never executed. The trailing ';' stops Jupyter from\n",
        "# echoing the whole string as this cell's output.\n",
        "''' \n",
        "## Example from: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n",
        "\n",
        "# load model pretrained on MNLI\n",
        "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n",
        "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')\n",
        "\n",
        "# pose sequence as a NLI premise and label (politics) as a hypothesis\n",
        "premise = 'Who are you voting for in 2020?'\n",
        "hypothesis = 'This text is about politics.'\n",
        "\n",
        "# run through model pre-trained on MNLI\n",
        "input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n",
        "logits = model(input_ids)[0]\n",
        "\n",
        "# we throw away \"neutral\" (dim 1) and take the probability of\n",
        "# \"entailment\" (2) as the probability of the label being true \n",
        "entail_contradiction_logits = logits[:,[0,2]]\n",
        "probs = entail_contradiction_logits.softmax(dim=1)\n",
        "true_prob = probs[:,1].item() * 100\n",
        "print(f'Probability that the label is true: {true_prob:0.2f}%')\n",
        "''';"
      "execution_count": null,
        "def zero_shot_prediction(premise, hypotheses, hypothesis_template='{}'):\n",
        "    \"\"\"Score each candidate label against ``premise`` with the MNLI model.\n",
        "\n",
        "    Each hypothesis is paired with the premise as an NLI sequence pair; the\n",
        "    'neutral' logit is discarded and the softmax over contradiction vs.\n",
        "    entailment gives the probability that the label applies.\n",
        "\n",
        "    Args:\n",
        "        premise: text to classify. If premise plus hypothesis exceeds the\n",
        "            tokenizer's maximum length, only the premise is truncated.\n",
        "        hypotheses: candidate labels, scored one at a time.\n",
        "        hypothesis_template: format string turning a label into the NLI\n",
        "            hypothesis (e.g. 'This text is about {}.'). The default '{}'\n",
        "            uses the label itself, as before.\n",
        "\n",
        "    Returns:\n",
        "        List of entailment probabilities in percent (0-100), in the same\n",
        "        order as ``hypotheses``.\n",
        "\n",
        "    Relies on the global ``tokenizer`` and ``model`` loaded earlier.\n",
        "    \"\"\"\n",
        "    import torch\n",
        "\n",
        "    true_probs = []\n",
        "    # inference only: do not build the autograd graph (saves memory and time)\n",
        "    with torch.no_grad():\n",
        "        for hypothesis in hypotheses:\n",
        "            # truncate only the premise so the pair fits the model's maximum\n",
        "            # input length; the short hypothesis is always kept intact\n",
        "            input_ids = tokenizer.encode(premise, hypothesis_template.format(hypothesis),\n",
        "                                         return_tensors='pt', truncation='only_first')\n",
        "            logits = model(input_ids)[0]\n",
        "            # drop 'neutral' (dim 1); softmax over contradiction (0) vs.\n",
        "            # entailment (2) and keep the entailment probability\n",
        "            entail_contradiction_logits = logits[:, [0, 2]]\n",
        "            probs = entail_contradiction_logits.softmax(dim=1)\n",
        "            true_probs.append(probs[:, 1].item() * 100)\n",
        "    return true_probs\n",
        "def get_highest_score(true_probs, hypotheses):\n",
        "    # get index of hypothesis with highest score\n",
        "    highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n",
        "    # get hypothesis with highest score\n",
        "    highest_hypothesis = hypotheses[highest_index]\n",
        "    # get highest probability\n",
        "    highest_prob = true_probs[highest_index]\n",
        "    \n",
        "    return (highest_hypothesis, highest_prob)\n",
        "\n",
        "\n",
        "def get_sorted_scores(true_probs, hypotheses):\n",
        "\n",
        "   # sort hypotheses based on their scores\n",
        "    sorted_hypotheses = [hypothesis for _, hypothesis in sorted(zip(true_probs, hypotheses), reverse=True)]\n",
        "\n",
        "    # sort scores\n",
        "    sorted_scores = sorted(true_probs, reverse=True)\n",
        "    \n",
        "    return list(zip(sorted_hypotheses, sorted_scores))\n",
      "execution_count": null,
      "source": [
        "df[column_text].tolist()[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "premise = df[column_text].tolist()[0]\n",
        "\n",
        "true_probs = zero_shot_prediction(premise, classes)\n",
        "highest_score = get_highest_score(true_probs, classes)\n",
        "\n",
        "# print the results\n",
        "print(f'The hypothesis with the highest score is: \"{highest_score[0]}\" with a probability of {highest_score[1]:0.2f}%')\n",
        "\n",
        "\n",
        "probs = get_sorted_scores(true_probs, classes)\n",
        "probs\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "probs[0][0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Classify every article: one MNLI forward pass per (article, class) pair.\n",
        "pred_labels = []  # top-scoring class per article, in df row order\n",
        "prob_labels = []  # full (class, probability) ranking per article, in df row order\n",
        "\n",
        "for content in tqdm(df[column_text].tolist()):\n",
        "\n",
        "    # only the first 1024 characters are used as the premise\n",
        "    # NOTE(review): this is a character cut, not a token limit\n",
        "    true_probs = zero_shot_prediction(content[:1024], classes)\n",
        "    \n",
        "    pred_labels.append(get_highest_score(true_probs, classes)[0])\n",
        "    prob_labels.append(get_sorted_scores(true_probs, classes))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# Reference labels, aligned row-for-row with pred_labels (both follow df order)\n",
        "true_labels = df[column_class].tolist()"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "EDdA-Classification_Clustering.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "geode-classification-py39",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "16fac9c2d845f8e1f8c6fffffe3d3a0def61c7e42da17a08d00f279ad4dea797"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}