From 4731074b4a508b0da5fbf83704e845214b6e2123 Mon Sep 17 00:00:00 2001 From: Ludovic Moncla <moncla.ludovic@gmail.com> Date: Tue, 21 Mar 2023 13:18:00 +0100 Subject: [PATCH] Update Predict.ipynb --- notebooks/Predict.ipynb | 225 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 205 insertions(+), 20 deletions(-) diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb index 0dcd667..ac3cd1d 100644 --- a/notebooks/Predict.ipynb +++ b/notebooks/Predict.ipynb @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -106,7 +106,15 @@ "id": "dPOU-Efhf4ui", "outputId": "0bb7fd0e-e2fb-4477-e5f7-b408d0a1ced7" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "We will use the GPU\n" + ] + } + ], "source": [ "import torch\n", "\n", @@ -138,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": { "id": "SkErnwgMMbRj" }, @@ -181,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": { "id": "M2awiee1r0zV" }, @@ -203,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -212,7 +220,119 @@ "id": "erjPU3y8r0zW", "outputId": "e2b4a39d-a72b-4e7a-8b26-e709eb983df3" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>uid</th>\n", + " <th>lge-volume</th>\n", + " <th>lge-numero</th>\n", + " <th>lge-head</th>\n", + " <th>lge-page</th>\n", + " <th>lge-id</th>\n", + " <th>lge-content</th>\n", + " <th>lge-nbWords</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>lge_1_a-0</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>A</td>\n", + " <td>0</td>\n", + " <td>a-0</td>\n", + " <td>A(Ling.). Son vocal et première lettre de notr...</td>\n", + " <td>1761.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>lge_1_a-1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>A</td>\n", + " <td>1</td>\n", + " <td>a-1</td>\n", + " <td>A(Paléogr.). C’est à l’alphabet phénicien, on ...</td>\n", + " <td>839.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>lge_1_a-2</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-2</td>\n", + " <td>A(Log.). Cette voyelle désigne les proposition...</td>\n", + " <td>56.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>lge_1_a-3</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-3</td>\n", + " <td>A(Mus.). La lettre a est employée par les musi...</td>\n", + " <td>267.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>lge_1_a-4</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-4</td>\n", + " <td>A(Numis.). Dans la numismatique grecque, la le...</td>\n", + " <td>67.0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " uid lge-volume lge-numero lge-head lge-page lge-id \\\n", + "0 lge_1_a-0 1 1 A 0 a-0 \n", + "1 lge_1_a-1 1 2 A 1 a-1 \n", + "2 lge_1_a-2 1 3 A 4 a-2 \n", + "3 lge_1_a-3 1 4 A 4 a-3 \n", + "4 lge_1_a-4 1 5 A 4 a-4 \n", + "\n", + " lge-content lge-nbWords \n", + "0 A(Ling.). Son vocal et première lettre de notr... 1761.0 \n", + "1 A(Paléogr.). C’est à l’alphabet phénicien, on ... 839.0 \n", + "2 A(Log.). Cette voyelle désigne les proposition... 56.0 \n", + "3 A(Mus.). La lettre a est employée par les musi... 267.0 \n", + "4 A(Numis.). Dans la numismatique grecque, la le... 67.0 " + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df = pd.read_csv(path + filepath, sep=\"\\t\")\n", "df.head()" @@ -220,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": { "id": "Ndw4UtgWt_MJ" }, @@ -242,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "metadata": { "id": "0qDZ86qTr0zX" }, @@ -367,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -422,7 +542,15 @@ "id": "eGKU1J9Ar0zY", "outputId": "0a5f7fe5-7b5e-4c11-8a6e-7e85e8478b92" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Bert Tokenizer...\n" + ] + } + ], "source": [ "if model_name == 'bert-base-multilingual-cased' :\n", " print('Loading Bert Tokenizer...')\n", @@ -446,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, "metadata": { "id": "CN8EZst-r0zZ" }, @@ -479,9 +607,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "# https://huggingface.co/docs/transformers/main_classes/pipelines\n", "\n", @@ -497,19 +634,67 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13 0.9375858902931213 2 0.021192006766796112 7 0.012656938284635544\n", + "6 0.9926056861877441 7 0.0029343003407120705 8 0.0010190330212935805\n", + "13 0.9823671579360962 2 0.00412388751283288 1 0.0022031590342521667\n", + "10 0.9058945775032043 2 0.029459038749337196 7 0.014979560859501362\n", + "7 0.9861114025115967 2 0.003949115984141827 6 0.0015271392185240984\n", + "4 0.9868664741516113 5 0.002140316180884838 15 0.0018120049498975277\n", + "6 0.9541037678718567 7 0.025117166340351105 8 0.00887206755578518\n", + "6 0.9981995820999146 7 0.00028012823895551264 8 0.00019026087829843163\n", + "6 0.9958584904670715 8 0.0010782132158055902 7 0.000548136536963284\n", + "6 0.9979164004325867 7 0.0005610007210634649 9 0.00018632493447512388\n", + "6 0.997787356376648 7 0.0003991609555669129 8 0.0002408416330581531\n", + "6 0.9979755282402039 7 0.0005005390848964453 8 0.0002189433143939823\n", + "11 0.9915592074394226 3 0.00250804889947176 14 0.0010435001458972692\n", + "7 0.958525538444519 2 0.011816944926977158 5 0.009215029887855053\n", + "8 0.2876076102256775 7 0.2462710738182068 2 0.17002692818641663\n", + "8 0.9409826397895813 7 0.03510138392448425 6 0.007040794938802719\n", + "8 0.3623795211315155 1 0.3142264485359192 7 0.13734686374664307\n", + "7 0.7184596061706543 6 0.11600398272275925 8 0.09759759902954102\n", + "7 0.8406069278717041 6 0.12032385170459747 2 0.009349718689918518\n", + "7 0.978775143623352 2 0.005065936129540205 4 0.0037283776327967644\n", + "6 0.4818583130836487 9 0.22724471986293793 5 0.07886118441820145\n", + "6 0.9740952253341675 8 0.015889622271060944 1 0.001933401683345437\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/qm/v_b1md29221_cnpcxf5qc43c0000gn/T/ipykernel_15176/3568789409.py\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mout\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mtokenizer_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# label ### TODO modifier ici\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader_batch_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/base.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0minference_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mforward_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 991\u001b[0m \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 992\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py\u001b[0m in \u001b[0;36m_forward\u001b[0;34m(self, model_inputs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpostprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunction_to_apply\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_legacy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1550\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1552\u001b[0;31m outputs = self.bert(\n\u001b[0m\u001b[1;32m 1553\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 985\u001b[0m \u001b[0;31m# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 986\u001b[0m \u001b[0;31m# ourselves in which case we just need to make it broadcastable to all heads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 987\u001b[0;31m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_extended_attention_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 988\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0;31m# If a 2D or 3D attention mask is provided for the cross-attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/modeling_utils.py\u001b[0m in \u001b[0;36mget_extended_attention_mask\u001b[0;34m(self, attention_mask, input_shape, device, dtype)\u001b[0m\n\u001b[1;32m 789\u001b[0m \u001b[0;31m# effectively the same as removing these entirely.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# fp16 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 791\u001b[0;31m \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1.0\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 792\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 793\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36m__rsub__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 831\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 832\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__rsub__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 833\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_VariableFunctions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrsub\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 834\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 835\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "pred = []\n", - "cpt = 0\n", + "\n", "for out in pipe(data(), **tokenizer_kwargs):\n", " out = sorted(out, key=lambda d: d['score'], reverse=True) \n", - " print(int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']) # label ### TODO modifier ici\n", + " #print(int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']) # label ### TODO modifier ici\n", " pred.append([int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']])\n", - " cpt += 1\n", - " if cpt == 6:\n", - " break\n", "\n", "pred = np.array(pred)" ] -- GitLab