{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "aXLlx8vXQlJw" }, "source": [ "# Zero-Shot Topic Classification with Transformers\n", "\n", "Blog post: https://joeddav.github.io/blog/2020/05/29/ZSL.html\n", "\n", "Original Colab notebook: https://colab.research.google.com/github/joeddav/blog/blob/master/_notebooks/2020-05-29-ZSL.ipynb#scrollTo=La_ga8KvSFYd\n", "\n", "Interactive demo: https://huggingface.co/spaces/joeddav/zero-shot-demo" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "3kYI_pq3Q1BT" }, "source": [ "## 1. Configuration" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "P_L0rDhZQ6Fn" }, "source": [ "### 1.1 Set up the Colab environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FsAR4CsB3aUc", "outputId": "e0791012-6858-4ee0-f724-7f33c6985ee8" }, "outputs": [], "source": [ "from psutil import virtual_memory\n", "ram_gb = virtual_memory().total / 1e9\n", "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n", "\n", "if ram_gb < 20:\n", "    print('Not using a high-RAM runtime')\n", "else:\n", "    print('You are using a high-RAM runtime!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "h5MwRwL53aYY", "outputId": "20a93907-e5df-47b1-9172-d1693ef76dc5" }, "outputs": [], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"drive/MyDrive/Classification-EDdA/\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "4z78CLYi75kV" }, "source": [ "### 1.2 Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bcptSr6o3ac7" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "from transformers import BartForSequenceClassification, BartTokenizer\n" ] }, { "attachments": 
{}, "cell_type": "markdown", "metadata": { "id": "Lc1DRh4b7mto" }, "source": [ "## 2. Load datasets" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Download datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ybiJYL0h3ahh", "outputId": "0638f9a2-f9a0-4d96-9760-991ddc5747ca" }, "outputs": [], "source": [ "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/EDdA_dataframe_withContent.tsv\n", "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv\n", "!wget https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset_path = 'EDdA_dataframe_withContent.tsv'\n", "training_set_path = 'training_set.tsv'\n", "test_set_path = 'test_set.tsv'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LRKJzWmf3pCg", "outputId": "686c3ef4-8267-4266-95af-7193725aadca" }, "outputs": [], "source": [ "df = pd.read_csv(test_set_path, sep=\"\\t\")\n", "\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "column_text = 'contentWithoutClass'\n", "column_class = 'ensemble_domaine_enccre'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 110 }, "id": "RbsHOiJdNYRL", "outputId": "bbdafc35-cf09-4a20-c3c0-901b8adce561" }, "outputs": [], "source": [ "df[column_text].iloc[0]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Classification" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The approach, proposed by [Yin et al. 
(2019)](https://arxiv.org/abs/1909.00161), uses a sequence-pair classifier pre-trained on MNLI (Multi-Genre Natural Language Inference) as an out-of-the-box zero-shot text classifier, and it works well in practice with no task-specific training. The idea is to take the sequence we want to label as the \"premise\" and to turn each candidate label into a \"hypothesis\" (for example, the label *politics* becomes \"This text is about politics.\"). If the NLI model predicts that the premise \"entails\" the hypothesis, we take the label to be true. The code cells below show how easily this can be done with 🤗 Transformers for a single premise/hypothesis pair." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load BART-large fine-tuned on MNLI\n", "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')\n", "model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# pose the sequence as an NLI premise and the label (politics) as a hypothesis\n", "premise = 'Who are you voting for in 2020?'\n", "hypothesis = 'This text is about politics.'\n", "\n", "# run the pair through the MNLI model; logits have shape (1, 3) with\n", "# index 0 = contradiction, 1 = neutral, 2 = entailment\n", "input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n", "logits = model(input_ids).logits\n", "\n", "# throw away \"neutral\" (index 1) and take the probability of \"entailment\" (index 2)\n", "# versus \"contradiction\" (index 0) as the probability of the label being true\n", "entail_contradiction_logits = logits[:, [0, 2]]\n", "probs = entail_contradiction_logits.softmax(dim=1)\n", "true_prob = probs[:, 1].item() * 100\n", "print(f'Probability that the label is true: {true_prob:0.2f}%')"
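] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The cell above scores a single candidate label. To choose among several labels, the same idea can be applied with one hypothesis per label. The next cell is a minimal sketch that is not part of the original blog post. The label list and the hypothesis template `'This text is about {}.'` are illustrative assumptions. The cell reuses the `tokenizer`, `model` and `premise` defined above. It reports two scores per label. The first is an independent (multi-label) probability, computed exactly as above. The second is an exclusive (single-label) probability, obtained with a softmax over the entailment logits of all candidate labels. As far as we know, this is also how the 🤗 `zero-shot-classification` pipeline scores labels when `multi_label=False`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch (hypothetical extension): one NLI hypothesis per candidate label.\n", "# `candidate_labels` and the hypothesis template are illustrative assumptions.\n", "import torch\n", "\n", "candidate_labels = ['politics', 'sports', 'economy']\n", "hypotheses = [f'This text is about {label}.' for label in candidate_labels]\n", "\n", "# encode one (premise, hypothesis) pair per label as a padded batch\n", "inputs = tokenizer([premise] * len(hypotheses), hypotheses,\n", "                   return_tensors='pt', padding=True)\n", "with torch.no_grad():\n", "    label_logits = model(**inputs).logits  # shape: (n_labels, 3)\n", "\n", "# multi-label: each label scored on its own (contradiction vs. entailment)\n", "independent_probs = label_logits[:, [0, 2]].softmax(dim=1)[:, 1]\n", "\n", "# single-label: softmax over the entailment logits of all candidate labels\n", "exclusive_probs = label_logits[:, 2].softmax(dim=0)\n", "\n", "for label, p_ind, p_exc in zip(candidate_labels, independent_probs, exclusive_probs):\n", "    print(f'{label:>10}: independent {p_ind.item() * 100:5.2f}% | exclusive {p_exc.item() * 100:5.2f}%')"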