From 4731074b4a508b0da5fbf83704e845214b6e2123 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Tue, 21 Mar 2023 13:18:00 +0100
Subject: [PATCH] Update Predict.ipynb

---
 notebooks/Predict.ipynb | 225 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 205 insertions(+), 20 deletions(-)

diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index 0dcd667..ac3cd1d 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -98,7 +98,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 40,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -106,7 +106,15 @@
         "id": "dPOU-Efhf4ui",
         "outputId": "0bb7fd0e-e2fb-4477-e5f7-b408d0a1ced7"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "We will use the GPU\n"
+          ]
+        }
+      ],
       "source": [
         "import torch\n",
         "\n",
@@ -138,7 +146,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 41,
       "metadata": {
         "id": "SkErnwgMMbRj"
       },
@@ -181,7 +189,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 42,
       "metadata": {
         "id": "M2awiee1r0zV"
       },
@@ -203,7 +211,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 43,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -212,7 +220,119 @@
         "id": "erjPU3y8r0zW",
         "outputId": "e2b4a39d-a72b-4e7a-8b26-e709eb983df3"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<div>\n",
+              "<style scoped>\n",
+              "    .dataframe tbody tr th:only-of-type {\n",
+              "        vertical-align: middle;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe tbody tr th {\n",
+              "        vertical-align: top;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe thead th {\n",
+              "        text-align: right;\n",
+              "    }\n",
+              "</style>\n",
+              "<table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              "    <tr style=\"text-align: right;\">\n",
+              "      <th></th>\n",
+              "      <th>uid</th>\n",
+              "      <th>lge-volume</th>\n",
+              "      <th>lge-numero</th>\n",
+              "      <th>lge-head</th>\n",
+              "      <th>lge-page</th>\n",
+              "      <th>lge-id</th>\n",
+              "      <th>lge-content</th>\n",
+              "      <th>lge-nbWords</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <th>0</th>\n",
+              "      <td>lge_1_a-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1</td>\n",
+              "      <td>A</td>\n",
+              "      <td>0</td>\n",
+              "      <td>a-0</td>\n",
+              "      <td>A(Ling.). Son vocal et première lettre de notr...</td>\n",
+              "      <td>1761.0</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>1</th>\n",
+              "      <td>lge_1_a-1</td>\n",
+              "      <td>1</td>\n",
+              "      <td>2</td>\n",
+              "      <td>A</td>\n",
+              "      <td>1</td>\n",
+              "      <td>a-1</td>\n",
+              "      <td>A(Paléogr.). C’est à l’alphabet phénicien, on ...</td>\n",
+              "      <td>839.0</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>2</th>\n",
+              "      <td>lge_1_a-2</td>\n",
+              "      <td>1</td>\n",
+              "      <td>3</td>\n",
+              "      <td>A</td>\n",
+              "      <td>4</td>\n",
+              "      <td>a-2</td>\n",
+              "      <td>A(Log.). Cette voyelle désigne les proposition...</td>\n",
+              "      <td>56.0</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>3</th>\n",
+              "      <td>lge_1_a-3</td>\n",
+              "      <td>1</td>\n",
+              "      <td>4</td>\n",
+              "      <td>A</td>\n",
+              "      <td>4</td>\n",
+              "      <td>a-3</td>\n",
+              "      <td>A(Mus.). La lettre a est employée par les musi...</td>\n",
+              "      <td>267.0</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>4</th>\n",
+              "      <td>lge_1_a-4</td>\n",
+              "      <td>1</td>\n",
+              "      <td>5</td>\n",
+              "      <td>A</td>\n",
+              "      <td>4</td>\n",
+              "      <td>a-4</td>\n",
+              "      <td>A(Numis.). Dans la numismatique grecque, la le...</td>\n",
+              "      <td>67.0</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table>\n",
+              "</div>"
+            ],
+            "text/plain": [
+              "         uid  lge-volume  lge-numero lge-head  lge-page lge-id  \\\n",
+              "0  lge_1_a-0           1           1        A         0    a-0   \n",
+              "1  lge_1_a-1           1           2        A         1    a-1   \n",
+              "2  lge_1_a-2           1           3        A         4    a-2   \n",
+              "3  lge_1_a-3           1           4        A         4    a-3   \n",
+              "4  lge_1_a-4           1           5        A         4    a-4   \n",
+              "\n",
+              "                                         lge-content  lge-nbWords  \n",
+              "0  A(Ling.). Son vocal et première lettre de notr...       1761.0  \n",
+              "1  A(Paléogr.). C’est à l’alphabet phénicien, on ...        839.0  \n",
+              "2  A(Log.). Cette voyelle désigne les proposition...         56.0  \n",
+              "3  A(Mus.). La lettre a est employée par les musi...        267.0  \n",
+              "4  A(Numis.). Dans la numismatique grecque, la le...         67.0  "
+            ]
+          },
+          "execution_count": 43,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
       "source": [
         "df = pd.read_csv(path + filepath, sep=\"\\t\")\n",
         "df.head()"
@@ -220,7 +340,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 44,
       "metadata": {
         "id": "Ndw4UtgWt_MJ"
       },
@@ -242,7 +362,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 45,
       "metadata": {
         "id": "0qDZ86qTr0zX"
       },
@@ -367,7 +487,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 46,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -422,7 +542,15 @@
         "id": "eGKU1J9Ar0zY",
         "outputId": "0a5f7fe5-7b5e-4c11-8a6e-7e85e8478b92"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Loading Bert Tokenizer...\n"
+          ]
+        }
+      ],
       "source": [
         "if model_name == 'bert-base-multilingual-cased' :\n",
         "    print('Loading Bert Tokenizer...')\n",
@@ -446,7 +574,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 47,
       "metadata": {
         "id": "CN8EZst-r0zZ"
       },
@@ -479,9 +607,18 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 48,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated,  if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
+            "  warnings.warn(\n"
+          ]
+        }
+      ],
       "source": [
         "# https://huggingface.co/docs/transformers/main_classes/pipelines\n",
         "\n",
@@ -497,19 +634,67 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 49,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "13 0.9375858902931213 2 0.021192006766796112 7 0.012656938284635544\n",
+            "6 0.9926056861877441 7 0.0029343003407120705 8 0.0010190330212935805\n",
+            "13 0.9823671579360962 2 0.00412388751283288 1 0.0022031590342521667\n",
+            "10 0.9058945775032043 2 0.029459038749337196 7 0.014979560859501362\n",
+            "7 0.9861114025115967 2 0.003949115984141827 6 0.0015271392185240984\n",
+            "4 0.9868664741516113 5 0.002140316180884838 15 0.0018120049498975277\n",
+            "6 0.9541037678718567 7 0.025117166340351105 8 0.00887206755578518\n",
+            "6 0.9981995820999146 7 0.00028012823895551264 8 0.00019026087829843163\n",
+            "6 0.9958584904670715 8 0.0010782132158055902 7 0.000548136536963284\n",
+            "6 0.9979164004325867 7 0.0005610007210634649 9 0.00018632493447512388\n",
+            "6 0.997787356376648 7 0.0003991609555669129 8 0.0002408416330581531\n",
+            "6 0.9979755282402039 7 0.0005005390848964453 8 0.0002189433143939823\n",
+            "11 0.9915592074394226 3 0.00250804889947176 14 0.0010435001458972692\n",
+            "7 0.958525538444519 2 0.011816944926977158 5 0.009215029887855053\n",
+            "8 0.2876076102256775 7 0.2462710738182068 2 0.17002692818641663\n",
+            "8 0.9409826397895813 7 0.03510138392448425 6 0.007040794938802719\n",
+            "8 0.3623795211315155 1 0.3142264485359192 7 0.13734686374664307\n",
+            "7 0.7184596061706543 6 0.11600398272275925 8 0.09759759902954102\n",
+            "7 0.8406069278717041 6 0.12032385170459747 2 0.009349718689918518\n",
+            "7 0.978775143623352 2 0.005065936129540205 4 0.0037283776327967644\n",
+            "6 0.4818583130836487 9 0.22724471986293793 5 0.07886118441820145\n",
+            "6 0.9740952253341675 8 0.015889622271060944 1 0.001933401683345437\n"
+          ]
+        },
+        {
+          "ename": "KeyboardInterrupt",
+          "evalue": "",
+          "output_type": "error",
+          "traceback": [
+            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+            "\u001b[0;32m/var/folders/qm/v_b1md29221_cnpcxf5qc43c0000gn/T/ipykernel_15176/3568789409.py\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mout\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mtokenizer_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# label ### TODO modifier ici\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m         \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m         \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    113\u001b[0m         \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m         \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m         \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    116\u001b[0m         \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    117\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader_batch_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/base.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m    988\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0minference_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    989\u001b[0m                     \u001b[0mmodel_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m                     \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mforward_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    991\u001b[0m                     \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    992\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py\u001b[0m in \u001b[0;36m_forward\u001b[0;34m(self, model_inputs)\u001b[0m\n\u001b[1;32m    165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mpostprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunction_to_apply\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_legacy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1188\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1191\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1192\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1550\u001b[0m         \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1552\u001b[0;31m         outputs = self.bert(\n\u001b[0m\u001b[1;32m   1553\u001b[0m             \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1554\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1188\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1191\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1192\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    985\u001b[0m         \u001b[0;31m# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    986\u001b[0m         \u001b[0;31m# ourselves in which case we just need to make it broadcastable to all heads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 987\u001b[0;31m         \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_extended_attention_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    988\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    989\u001b[0m         \u001b[0;31m# If a 2D or 3D attention mask is provided for the cross-attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/modeling_utils.py\u001b[0m in \u001b[0;36mget_extended_attention_mask\u001b[0;34m(self, attention_mask, input_shape, device, dtype)\u001b[0m\n\u001b[1;32m    789\u001b[0m         \u001b[0;31m# effectively the same as removing these entirely.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    790\u001b[0m         \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# fp16 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 791\u001b[0;31m         \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1.0\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    792\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    793\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     37\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     38\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     40\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     41\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36m__rsub__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m    831\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    832\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__rsub__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 833\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_VariableFunctions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrsub\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    834\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    835\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+          ]
+        }
+      ],
       "source": [
         "pred = []\n",
-        "cpt = 0\n",
+        "\n",
         "for out in pipe(data(), **tokenizer_kwargs):\n",
         "    out = sorted(out, key=lambda d: d['score'], reverse=True) \n",
-        "    print(int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']) # label ### TODO modifier ici\n",
+        "    #print(int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']) # label ### TODO modifier ici\n",
         "    pred.append([int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']])\n",
-        "    cpt += 1\n",
-        "    if cpt == 6:\n",
-        "        break\n",
         "\n",
         "pred = np.array(pred)"
       ]
-- 
GitLab