diff --git a/notebooks/Classification_Zero-Shot-Learning.ipynb b/notebooks/Classification_Zero-Shot-Learning.ipynb index 2869c849d24c9a7949ecb160ff9b445a13069ec0..1ea10781cb40a4ed28467330fbc236a9e630951b 100644 --- a/notebooks/Classification_Zero-Shot-Learning.ipynb +++ b/notebooks/Classification_Zero-Shot-Learning.ipynb @@ -453,33 +453,6 @@ "classes" ] }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 110 - }, - "id": "RbsHOiJdNYRL", - "outputId": "bbdafc35-cf09-4a20-c3c0-901b8adce561" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\"ORNIS, s. m. toile des Indes, (Comm.) sortes de\\ntoiles de coton ou de mousseline, qui se font a Brampour ville de l'Indoustan, entre Surate & Agra. Ces\\ntoiles sont par bandes, moitié coton & moitié or &\\nargent. Il y en a depuis quinze jusqu'à vingt aunes.\"" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[column_text].tolist()[0]" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -610,58 +583,50 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n" - ] - } - ], + "outputs": [], "source": [ - "# pose sequence as a NLI premise and label (politics) as a hypothesis\n", - "premise = df[column_text].tolist()[0]\n", - "#hypothesis = 'This text is about politics.'\n", - "hypotheses = classes\n", + "def zero_shot_prediction(premise, hypotheses):\n", + " # list to store the true probability of each hypothesis\n", + " true_probs = []\n", "\n", - "# list to store the true probability of each hypothesis\n", - "true_probs = []\n", + " # loop through hypotheses\n", + " for hypothesis in hypotheses:\n", "\n", - "# loop through hypotheses\n", - "for hypothesis in hypotheses:\n", + " # run through model pre-trained on MNLI\n", + " input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n", + " logits = model(input_ids)[0]\n", "\n", - " # run through model pre-trained on MNLI\n", - " input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')\n", - " logits = model(input_ids)[0]\n", + " # we throw away \"neutral\" (dim 1) and take the probability of\n", + " # \"entailment\" (2) as the probability of the label being true \n", + " entail_contradiction_logits = logits[:,[0,2]]\n", + " probs = entail_contradiction_logits.softmax(dim=1)\n", + " true_prob = probs[:,1].item() * 100\n", "\n", - " # we throw away \"neutral\" (dim 1) and take the probability of\n", - " # \"entailment\" (2) as the probability of the label being true \n", - " entail_contradiction_logits = logits[:,[0,2]]\n", - " probs = entail_contradiction_logits.softmax(dim=1)\n", - " true_prob = probs[:,1].item() * 100\n", + " # append true probability to list\n", + " true_probs.append(true_prob)\n", "\n", - " # append true probability to list\n", - " true_probs.append(true_prob)\n", + " return true_probs\n", "\n", - "# print the true probability for each hypothesis\n", - "#for i, hypothesis in enumerate(hypotheses):\n", - "# print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n", - "# print(f'Probability that the label is true: {true_prob:0.2f}%')\n", "\n", - "# get index of hypothesis with highest score\n", - "highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n", + "def get_highest_score(true_probs, hypotheses):\n", + " # print the true probability for each hypothesis\n", + " #for i, hypothesis in enumerate(hypotheses):\n", + " # print(f'Probability that hypothesis \"{hypothesis}\" is true: {true_probs[i]:0.2f}%')\n", + " # print(f'Probability that the label is true: {true_prob:0.2f}%')\n", "\n", - "# get hypothesis with highest score\n", - "highest_hypothesis = hypotheses[highest_index]\n", + " # get index of hypothesis with highest score\n", + " highest_index = max(range(len(true_probs)), key=lambda i: true_probs[i])\n", "\n", - "# get highest probability\n", - "highest_prob = true_probs[highest_index]\n", + " # get hypothesis with highest score\n", + " highest_hypothesis = hypotheses[highest_index]\n", "\n", - "# print the results\n", - "print(f'The hypothesis with the highest score is: \"{highest_hypothesis}\" with a probability of {highest_prob:0.2f}%')" + " # get highest probability\n", + " highest_prob = true_probs[highest_index]\n", + " \n", + " return (highest_hypothesis, highest_prob)\n", + " " ] }, { @@ -669,7 +634,32 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "df[column_text].tolist()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The hypothesis with the highest score is: \"Commerce\" with a probability of 70.05%\n" + ] + } + ], + "source": [ + "premise = df[column_text].tolist()[0]\n", + "\n", + "true_probs = zero_shot_prediction(premise, classes)\n", + "highest_score = get_highest_score(true_probs, classes)\n", + "\n", + "# print the results\n", + "print(f'The hypothesis with the highest score is: \"{highest_score[0]}\" with a probability of {highest_score[1]:0.2f}%')" + ] }, { "cell_type": "code",