diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb index ac3cd1d5c76ea6bccb4edfe94ea3851e81a5499d..1ea2070ad60244e6583b5d8c2da8849c4d135d62 100644 --- a/notebooks/Predict.ipynb +++ b/notebooks/Predict.ipynb @@ -146,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 51, "metadata": { "id": "SkErnwgMMbRj" }, @@ -154,6 +154,7 @@ "source": [ "import pandas as pd \n", "import numpy as np\n", + "from tqdm import tqdm\n", "\n", "from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, TextClassificationPipeline\n", "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", @@ -634,64 +635,21 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 52, "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "13 0.9375858902931213 2 0.021192006766796112 7 0.012656938284635544\n", - "6 0.9926056861877441 7 0.0029343003407120705 8 0.0010190330212935805\n", - "13 0.9823671579360962 2 0.00412388751283288 1 0.0022031590342521667\n", - "10 0.9058945775032043 2 0.029459038749337196 7 0.014979560859501362\n", - "7 0.9861114025115967 2 0.003949115984141827 6 0.0015271392185240984\n", - "4 0.9868664741516113 5 0.002140316180884838 15 0.0018120049498975277\n", - "6 0.9541037678718567 7 0.025117166340351105 8 0.00887206755578518\n", - "6 0.9981995820999146 7 0.00028012823895551264 8 0.00019026087829843163\n", - "6 0.9958584904670715 8 0.0010782132158055902 7 0.000548136536963284\n", - "6 0.9979164004325867 7 0.0005610007210634649 9 0.00018632493447512388\n", - "6 0.997787356376648 7 0.0003991609555669129 8 0.0002408416330581531\n", - "6 0.9979755282402039 7 0.0005005390848964453 8 0.0002189433143939823\n", - "11 0.9915592074394226 3 0.00250804889947176 14 0.0010435001458972692\n", - "7 0.958525538444519 2 0.011816944926977158 5 0.009215029887855053\n", - "8 0.2876076102256775 7 0.2462710738182068 2 0.17002692818641663\n", - "8 0.9409826397895813 7 0.03510138392448425 6 0.007040794938802719\n", - "8 0.3623795211315155 1 0.3142264485359192 7 0.13734686374664307\n", - "7 0.7184596061706543 6 0.11600398272275925 8 0.09759759902954102\n", - "7 0.8406069278717041 6 0.12032385170459747 2 0.009349718689918518\n", - "7 0.978775143623352 2 0.005065936129540205 4 0.0037283776327967644\n", - "6 0.4818583130836487 9 0.22724471986293793 5 0.07886118441820145\n", - "6 0.9740952253341675 8 0.015889622271060944 1 0.001933401683345437\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/qm/v_b1md29221_cnpcxf5qc43c0000gn/T/ipykernel_15176/3568789409.py\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mout\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mtokenizer_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# label ### TODO modifier ici\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader_batch_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/base.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0minference_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mforward_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 991\u001b[0m \u001b[0mmodel_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ensure_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 992\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py\u001b[0m in \u001b[0;36m_forward\u001b[0;34m(self, model_inputs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpostprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunction_to_apply\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_legacy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1550\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1552\u001b[0;31m outputs = self.bert(\n\u001b[0m\u001b[1;32m 1553\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 985\u001b[0m \u001b[0;31m# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 986\u001b[0m \u001b[0;31m# ourselves in which case we just need to make it broadcastable to all heads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 987\u001b[0;31m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_extended_attention_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 988\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0;31m# If a 2D or 3D attention mask is provided for the cross-attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/modeling_utils.py\u001b[0m in \u001b[0;36mget_extended_attention_mask\u001b[0;34m(self, attention_mask, input_shape, device, dtype)\u001b[0m\n\u001b[1;32m 789\u001b[0m \u001b[0;31m# effectively the same as removing these entirely.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# fp16 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 791\u001b[0;31m \u001b[0mextended_attention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1.0\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 792\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 793\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36m__rsub__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 831\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 832\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__rsub__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 833\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_VariableFunctions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrsub\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 834\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 835\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "134820it [1:07:31, 33.27it/s]\n" ] } ], "source": [ "pred = []\n", "\n", - "for out in pipe(data(), **tokenizer_kwargs):\n", + "for out in tqdm(pipe(data(), **tokenizer_kwargs)):\n", " out = sorted(out, key=lambda d: d['score'], reverse=True) \n", " #print(int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']) # label ### TODO modifier ici\n", " pred.append([int(out[0]['label'][6:]), out[0]['score'], int(out[1]['label'][6:]), out[1]['score'], int(out[2]['label'][6:]), out[2]['score']])\n", @@ -701,11 +659,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 53, "metadata": { "id": "fo6k4li1r0za" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 1.0.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", + "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", + " warnings.warn(\n" + ] + } + ], "source": [ "# Load label encoder\n", "\n", @@ -717,7 +685,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 54, "metadata": { "id": "UU7qg7zVr0zb" }, @@ -730,7 +698,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -740,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 56, "metadata": { "id": "w4eHpBztr0zb" }, @@ -756,7 +724,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -765,14 +733,281 @@ "id": "OCy54lRLr0zb", "outputId": "a42d8a75-48b9-431a-9b8e-71e4d7018c6b" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>uid</th>\n", + " <th>lge-volume</th>\n", + " <th>lge-numero</th>\n", + " <th>lge-head</th>\n", + " <th>lge-page</th>\n", + " <th>lge-id</th>\n", + " <th>lge-content</th>\n", + " <th>lge-nbWords</th>\n", + " <th>lge-superdomainPred1</th>\n", + " <th>lge-superdomainProba1</th>\n", + " <th>lge-superdomainPred2</th>\n", + " <th>lge-superdomainProba2</th>\n", + " <th>lge-superdomainPred3</th>\n", + " <th>lge-superdomainProba3</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>lge_1_a-0</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>A</td>\n", + " <td>0</td>\n", + " <td>a-0</td>\n", + " <td>A(Ling.). Son vocal et première lettre de notr...</td>\n", + " <td>1761.0</td>\n", + " <td>Philosophie</td>\n", + " <td>0.937586</td>\n", + " <td>Belles-lettres</td>\n", + " <td>0.021192</td>\n", + " <td>Histoire</td>\n", + " <td>0.012657</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>lge_1_a-1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>A</td>\n", + " <td>1</td>\n", + " <td>a-1</td>\n", + " <td>A(Paléogr.). C’est à l’alphabet phénicien, on ...</td>\n", + " <td>839.0</td>\n", + " <td>Géographie</td>\n", + " <td>0.992606</td>\n", + " <td>Histoire</td>\n", + " <td>0.002934</td>\n", + " <td>Histoire naturelle</td>\n", + " <td>0.001019</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>lge_1_a-2</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-2</td>\n", + " <td>A(Log.). Cette voyelle désigne les proposition...</td>\n", + " <td>56.0</td>\n", + " <td>Philosophie</td>\n", + " <td>0.982367</td>\n", + " <td>Belles-lettres</td>\n", + " <td>0.004124</td>\n", + " <td>Beaux-arts</td>\n", + " <td>0.002203</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>lge_1_a-3</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-3</td>\n", + " <td>A(Mus.). La lettre a est employée par les musi...</td>\n", + " <td>267.0</td>\n", + " <td>Musique</td>\n", + " <td>0.905895</td>\n", + " <td>Belles-lettres</td>\n", + " <td>0.029459</td>\n", + " <td>Histoire</td>\n", + " <td>0.014980</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>lge_1_a-4</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>a-4</td>\n", + " <td>A(Numis.). Dans la numismatique grecque, la le...</td>\n", + " <td>67.0</td>\n", + " <td>Histoire</td>\n", + " <td>0.986111</td>\n", + " <td>Belles-lettres</td>\n", + " <td>0.003949</td>\n", + " <td>Géographie</td>\n", + " <td>0.001527</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>lge_1_aa-0</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>AA</td>\n", + " <td>4</td>\n", + " <td>aa-0</td>\n", + " <td>AA. Ces deux lettres désignent l’atelier monét...</td>\n", + " <td>14.0</td>\n", + " <td>Commerce</td>\n", + " <td>0.986866</td>\n", + " <td>Droit Jurisprudence</td>\n", + " <td>0.002140</td>\n", + " <td>Politique</td>\n", + " <td>0.001812</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>lge_1_aa-1</td>\n", + " <td>1</td>\n", + " <td>7</td>\n", + " <td>AA</td>\n", + " <td>4</td>\n", + " <td>aa-1</td>\n", + " <td>AA. Nom de plusieurs cours d’eau de l’Europe o...</td>\n", + " <td>75.0</td>\n", + " <td>Géographie</td>\n", + " <td>0.954104</td>\n", + " <td>Histoire</td>\n", + " <td>0.025117</td>\n", + " <td>Histoire naturelle</td>\n", + " <td>0.008872</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>lge_1_aa-2</td>\n", + " <td>1</td>\n", + " <td>8</td>\n", + " <td>AA</td>\n", + " <td>5</td>\n", + " <td>aa-2</td>\n", + " <td>AA. Rivière de France, prend sa source aux Tro...</td>\n", + " <td>165.0</td>\n", + " <td>Géographie</td>\n", + " <td>0.998200</td>\n", + " <td>Histoire</td>\n", + " <td>0.000280</td>\n", + " <td>Histoire naturelle</td>\n", + " <td>0.000190</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>lge_1_aa-3</td>\n", + " <td>1</td>\n", + " <td>9</td>\n", + " <td>AA</td>\n", + " <td>5</td>\n", + " <td>aa-3</td>\n", + " <td>AA. Rivière de Hollande, affluent de la Dommel...</td>\n", + " <td>17.0</td>\n", + " <td>Géographie</td>\n", + " <td>0.995858</td>\n", + " <td>Histoire naturelle</td>\n", + " <td>0.001078</td>\n", + " <td>Histoire</td>\n", + " <td>0.000548</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>lge_1_aa-4</td>\n", + " <td>1</td>\n", + " <td>10</td>\n", + " <td>AA</td>\n", + " <td>5</td>\n", + " <td>aa-4</td>\n", + " <td>AA. Nom de deux fleuves de la Russie. Le premi...</td>\n", + " <td>71.0</td>\n", + " <td>Géographie</td>\n", + " <td>0.997916</td>\n", + " <td>Histoire</td>\n", + " <td>0.000561</td>\n", + " <td>Militaire</td>\n", + " <td>0.000186</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " uid lge-volume lge-numero lge-head lge-page lge-id \\\n", + "0 lge_1_a-0 1 1 A 0 a-0 \n", + "1 lge_1_a-1 1 2 A 1 a-1 \n", + "2 lge_1_a-2 1 3 A 4 a-2 \n", + "3 lge_1_a-3 1 4 A 4 a-3 \n", + "4 lge_1_a-4 1 5 A 4 a-4 \n", + "5 lge_1_aa-0 1 6 AA 4 aa-0 \n", + "6 lge_1_aa-1 1 7 AA 4 aa-1 \n", + "7 lge_1_aa-2 1 8 AA 5 aa-2 \n", + "8 lge_1_aa-3 1 9 AA 5 aa-3 \n", + "9 lge_1_aa-4 1 10 AA 5 aa-4 \n", + "\n", + " lge-content lge-nbWords \\\n", + "0 A(Ling.). Son vocal et première lettre de notr... 1761.0 \n", + "1 A(Paléogr.). C’est à l’alphabet phénicien, on ... 839.0 \n", + "2 A(Log.). Cette voyelle désigne les proposition... 56.0 \n", + "3 A(Mus.). La lettre a est employée par les musi... 267.0 \n", + "4 A(Numis.). Dans la numismatique grecque, la le... 67.0 \n", + "5 AA. Ces deux lettres désignent l’atelier monét... 14.0 \n", + "6 AA. Nom de plusieurs cours d’eau de l’Europe o... 75.0 \n", + "7 AA. Rivière de France, prend sa source aux Tro... 165.0 \n", + "8 AA. Rivière de Hollande, affluent de la Dommel... 17.0 \n", + "9 AA. Nom de deux fleuves de la Russie. Le premi... 71.0 \n", + "\n", + " lge-superdomainPred1 lge-superdomainProba1 lge-superdomainPred2 \\\n", + "0 Philosophie 0.937586 Belles-lettres \n", + "1 Géographie 0.992606 Histoire \n", + "2 Philosophie 0.982367 Belles-lettres \n", + "3 Musique 0.905895 Belles-lettres \n", + "4 Histoire 0.986111 Belles-lettres \n", + "5 Commerce 0.986866 Droit Jurisprudence \n", + "6 Géographie 0.954104 Histoire \n", + "7 Géographie 0.998200 Histoire \n", + "8 Géographie 0.995858 Histoire naturelle \n", + "9 Géographie 0.997916 Histoire \n", + "\n", + " lge-superdomainProba2 lge-superdomainPred3 lge-superdomainProba3 \n", + "0 0.021192 Histoire 0.012657 \n", + "1 0.002934 Histoire naturelle 0.001019 \n", + "2 0.004124 Beaux-arts 0.002203 \n", + "3 0.029459 Histoire 0.014980 \n", + "4 0.003949 Géographie 0.001527 \n", + "5 0.002140 Politique 0.001812 \n", + "6 0.025117 Histoire naturelle 0.008872 \n", + "7 0.000280 Histoire naturelle 0.000190 \n", + "8 0.001078 Histoire 0.000548 \n", + "9 0.000561 Militaire 0.000186 " + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df.head(10)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 58, "metadata": { "id": "J9rObbvVr0zc" }, @@ -784,31 +1019,84 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8cX6XBq8_F5T" - }, - "outputs": [], - "source": [ - "#df.drop(columns=['contentLGE', 'contentEDdA'], inplace=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 59, "metadata": { "id": "7TD1mbKj_fXH" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>uid</th>\n", + " <th>lge-volume</th>\n", + " <th>lge-numero</th>\n", + " <th>lge-head</th>\n", + " <th>lge-page</th>\n", + " <th>lge-id</th>\n", + " <th>lge-content</th>\n", + " <th>lge-nbWords</th>\n", + " <th>lge-superdomainPred1</th>\n", + " <th>lge-superdomainProba1</th>\n", + " <th>lge-superdomainPred2</th>\n", + " <th>lge-superdomainProba2</th>\n", + " <th>lge-superdomainPred3</th>\n", + " <th>lge-superdomainProba3</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + "Empty DataFrame\n", + "Columns: [uid, lge-volume, lge-numero, lge-head, lge-page, lge-id, lge-content, lge-nbWords, lge-superdomainPred1, lge-superdomainProba1, lge-superdomainPred2, lge-superdomainProba2, lge-superdomainPred3, lge-superdomainProba3]\n", + "Index: []" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df.loc[(df[corpus+'-superdomainProba1'] == 'Géographie')]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 60, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(134820, 14)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df.shape" ]