diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb index f479b3e92a836387e0e59adc5e671e6e587344be..76e7e0ac071aed40627043e8e26de73460c8d9b4 100644 --- a/notebooks/Predict.ipynb +++ b/notebooks/Predict.ipynb @@ -139,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "metadata": { "id": "SkErnwgMMbRj" }, @@ -330,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -341,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -445,7 +445,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -467,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -494,22 +494,9 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [26], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m#model = torch.load(model_path, map_location=torch.device('mps'))\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model\u001b[39m.\u001b[39;49mload_state_dict(torch\u001b[39m.\u001b[39;49mload(model_path, map_location\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mdevice(\u001b[39m'\u001b[39;49m\u001b[39mmps\u001b[39;49m\u001b[39m'\u001b[39;49m)))\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1620\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1597\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Copies parameters and buffers from :attr:`state_dict` into\u001b[39;00m\n\u001b[1;32m 1598\u001b[0m \u001b[39mthis module and its descendants. If :attr:`strict` is ``True``, then\u001b[39;00m\n\u001b[1;32m 1599\u001b[0m \u001b[39mthe keys of :attr:`state_dict` must exactly match the keys returned\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1617\u001b[0m \u001b[39m ``RuntimeError``.\u001b[39;00m\n\u001b[1;32m 1618\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1619\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(state_dict, Mapping):\n\u001b[0;32m-> 1620\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mExpected state_dict to be dict-like, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mtype\u001b[39m(state_dict)))\n\u001b[1;32m 1622\u001b[0m missing_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n\u001b[1;32m 1623\u001b[0m unexpected_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n", - "\u001b[0;31mTypeError\u001b[0m: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>." 
- ] - } - ], + "outputs": [], "source": [ "#model = torch.load(model_path, map_location=torch.device('mps'))\n", "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n", @@ -519,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -527,65 +514,493 @@ "id": "_fzgS5USJeAF", "outputId": "be4a5506-76ed-4eef-bb3c-fe2bb77c6e4d" }, + "outputs": [], + "source": [ + "pred = predict(model, data_loader, device)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'BertEncoder' object has no attribute 'gradient_checkpointing'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pred \u001b[39m=\u001b[39m predict(model, data_loader, device)\n", - "Cell \u001b[0;32mIn [8], line 68\u001b[0m, in \u001b[0;36mpredict\u001b[0;34m(model, dataloader, device)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[39m# Telling the model not to compute or store gradients, saving memory and\u001b[39;00m\n\u001b[1;32m 65\u001b[0m \u001b[39m# speeding up prediction\u001b[39;00m\n\u001b[1;32m 66\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m 67\u001b[0m \u001b[39m# Forward pass, calculate logit predictions\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m outputs \u001b[39m=\u001b[39m model(b_input_ids, token_type_ids\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 69\u001b[0m attention_mask\u001b[39m=\u001b[39;49mb_input_mask)\n\u001b[1;32m 71\u001b[0m logits \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 72\u001b[0m \u001b[39m#print(logits)\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \n\u001b[1;32m 74\u001b[0m \u001b[39m# Move logits and labels to CPU ???\u001b[39;00m\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File 
\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1552\u001b[0m, in \u001b[0;36mBertForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1544\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1545\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1546\u001b[0m \u001b[39m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1547\u001b[0m \u001b[39m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1548\u001b[0m \u001b[39m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1549\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1550\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1552\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbert(\n\u001b[1;32m 1553\u001b[0m input_ids,\n\u001b[1;32m 1554\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 1555\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 1556\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 1557\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 1558\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m 1559\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1560\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1561\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 1562\u001b[0m )\n\u001b[1;32m 1564\u001b[0m pooled_output \u001b[39m=\u001b[39m outputs[\u001b[39m1\u001b[39m]\n\u001b[1;32m 1566\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout(pooled_output)\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, 
\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1014\u001b[0m, in \u001b[0;36mBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1005\u001b[0m head_mask \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_head_mask(head_mask, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m 1007\u001b[0m embedding_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(\n\u001b[1;32m 1008\u001b[0m input_ids\u001b[39m=\u001b[39minput_ids,\n\u001b[1;32m 1009\u001b[0m position_ids\u001b[39m=\u001b[39mposition_ids,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1012\u001b[0m past_key_values_length\u001b[39m=\u001b[39mpast_key_values_length,\n\u001b[1;32m 1013\u001b[0m )\n\u001b[0;32m-> 1014\u001b[0m encoder_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mencoder(\n\u001b[1;32m 1015\u001b[0m embedding_output,\n\u001b[1;32m 1016\u001b[0m attention_mask\u001b[39m=\u001b[39;49mextended_attention_mask,\n\u001b[1;32m 1017\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 1018\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 1019\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_extended_attention_mask,\n\u001b[1;32m 1020\u001b[0m past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m 1021\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 1022\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1023\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1024\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 1025\u001b[0m )\n\u001b[1;32m 1026\u001b[0m sequence_output \u001b[39m=\u001b[39m encoder_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 1027\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler(sequence_output) \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks 
\u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:580\u001b[0m, in \u001b[0;36mBertEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 577\u001b[0m layer_head_mask \u001b[39m=\u001b[39m head_mask[i] \u001b[39mif\u001b[39;00m head_mask \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 578\u001b[0m past_key_value \u001b[39m=\u001b[39m past_key_values[i] \u001b[39mif\u001b[39;00m past_key_values \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> 580\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgradient_checkpointing \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining:\n\u001b[1;32m 582\u001b[0m \u001b[39mif\u001b[39;00m use_cache:\n\u001b[1;32m 583\u001b[0m logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m 584\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 585\u001b[0m )\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1265\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1263\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m 1264\u001b[0m \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1265\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1266\u001b[0m \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n", - "\u001b[0;31mAttributeError\u001b[0m: 'BertEncoder' object has no attribute 'gradient_checkpointing'" - ] + "data": { + "text/plain": [ + "[15,\n", + " 6,\n", + " 16,\n", + " 15,\n", + " 17,\n", + " 10,\n", + " 17,\n", + " 16,\n", + " 19,\n", + " 35,\n", + " 15,\n", + " 26,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 2,\n", + " 2,\n", + " 17,\n", + " 6,\n", + " 32,\n", + " 17,\n", + " 30,\n", + " 16,\n", + " 32,\n", + " 15,\n", + " 35,\n", + " 15,\n", + " 23,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 17,\n", + " 15,\n", + " 16,\n", + " 3,\n", + " 17,\n", + " 17,\n", + " 16,\n", + " 4,\n", + " 15,\n", + " 17,\n", + " 19,\n", + " 16,\n", + " 35,\n", + " 3,\n", + " 17,\n", + " 5,\n", + " 15,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 16,\n", + " 6,\n", + " 16,\n", + " 5,\n", + " 16,\n", + " 15,\n", + " 28,\n", + " 16,\n", + " 17,\n", + " 10,\n", + " 15,\n", + " 15,\n", + " 32,\n", + " 15,\n", + " 17,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 12,\n", + " 15,\n", + " 18,\n", + " 15,\n", + " 35,\n", + " 26,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 5,\n", + " 15,\n", + " 15,\n", + " 5,\n", + " 17,\n", + " 15,\n", + " 17,\n", + " 35,\n", + " 15,\n", + " 16,\n", + " 16,\n", + " 17,\n", + " 2,\n", + " 17,\n", + " 15,\n", + " 16,\n", + " 23,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 6,\n", + " 15,\n", + " 35,\n", + " 15,\n", + " 32,\n", + " 16,\n", + " 6,\n", + " 16,\n", + " 23,\n", + " 36,\n", + " 5,\n", + " 35,\n", + " 3,\n", + " 3,\n", + " 3,\n", + " 16,\n", + " 17,\n", + " 2,\n", + " 15,\n", + " 5,\n", + " 17,\n", + " 16,\n", + " 15,\n", + " 17,\n", + " 6,\n", + " 15,\n", + " 16,\n", + " 10,\n", + " 16,\n", + " 15,\n", + " 35,\n", + " 17,\n", + " 15,\n", + " 15,\n", + " 6,\n", + " 28,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 5,\n", + " 15,\n", + " 21,\n", + " 5,\n", + " 1,\n", + " 7,\n", + " 16,\n", + " 15,\n", + " 17,\n", + " 23,\n", + " 15,\n", + " 5,\n", + " 0,\n", + " 10,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 3,\n", + " 3,\n", + " 17,\n", + " 36,\n", + " 16,\n", + " 15,\n", + " 12,\n", + " 6,\n", + " 15,\n", + " 4,\n", + " 16,\n", + " 16,\n", + " 26,\n", + " 15,\n", + " 15,\n", + " 32,\n", + " 15,\n", + " 10,\n", + " 15,\n", + " 5,\n", + " 26,\n", + " 5,\n", + " 15,\n", + " 15,\n", + " 26,\n", + " 15,\n", + " 35,\n", + " 15,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 6,\n", + " 16,\n", + " 12,\n", + " 16,\n", + " 28,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 6,\n", + " 10,\n", + " 16,\n", + " 15,\n", + " 
15,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 5,\n", + " 16,\n", + " 16,\n", + " 17,\n", + " 15,\n", + " 16,\n", + " 35,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 6,\n", + " 29,\n", + " 16,\n", + " 15,\n", + " 5,\n", + " 5,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 31,\n", + " 16,\n", + " 15,\n", + " 16,\n", + " 15,\n", + " 6,\n", + " 16,\n", + " 3,\n", + " 15,\n", + " 2,\n", + " 15,\n", + " 15,\n", + " 28,\n", + " 17,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 10,\n", + " 15,\n", + " 5,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 17,\n", + " 15,\n", + " 5,\n", + " 15,\n", + " 3,\n", + " 15,\n", + " 2,\n", + " 15,\n", + " 15,\n", + " 6,\n", + " 15,\n", + " 28,\n", + " 15,\n", + " 6,\n", + " 15,\n", + " 32,\n", + " 16,\n", + " 15,\n", + " 2,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 17,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 16,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 35,\n", + " 15,\n", + " 15,\n", + " 35,\n", + " 16,\n", + " 28,\n", + " 15,\n", + " 15,\n", + " 15,\n", + " 5,\n", + " 15,\n", + " 15,\n", + " 19,\n", + " 15]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "pred = predict(model, data_loader, device)" + "pred" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import pickle \n", + "encoder_filename = \"models/label_encoder.pkl\"\n", + "with open(path+encoder_filename, 'rb') as file:\n", + " encoder = pickle.load(file)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "p2 = list(encoder.inverse_transform(pred))" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "df_LGE['class_bert'] = p2" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>id</th>\n", + " <th>tome</th>\n", + " <th>rank</th>\n", + " <th>domain</th>\n", + " <th>remark</th>\n", + " <th>content</th>\n", + " <th>class_bert</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>abrabeses-0</td>\n", + " <td>1</td>\n", + " <td>623</td>\n", + " <td>geography</td>\n", + " <td>NaN</td>\n", + " <td>ABRABESES. Village d’Espagne de la prov. de Za...</td>\n", + " <td>Géographie</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>accius-0</td>\n", + " <td>1</td>\n", + " <td>1076</td>\n", + " <td>biography</td>\n", + " <td>NaN</td>\n", + " <td>ACCIUS, L. ou L. ATTIUS (170-94 av. 
J.-C.), po...</td>\n", + " <td>Belles-lettres - Poésie</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>achenbach-2</td>\n", + " <td>1</td>\n", + " <td>1357</td>\n", + " <td>biography</td>\n", + " <td>NaN</td>\n", + " <td>ACHENBACH(Henri), administrateur prussien, né ...</td>\n", + " <td>Histoire</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>acireale-0</td>\n", + " <td>1</td>\n", + " <td>1513</td>\n", + " <td>geography</td>\n", + " <td>NaN</td>\n", + " <td>ACIREALE. Yille de Sicile, de la province et d...</td>\n", + " <td>Géographie</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>actée-0</td>\n", + " <td>1</td>\n", + " <td>1731</td>\n", + " <td>botany</td>\n", + " <td>NaN</td>\n", + " <td>ACTÉE(Actœa L.). Genre de plantes de la famill...</td>\n", + " <td>Histoire naturelle</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " id tome rank domain remark \\\n", + "0 abrabeses-0 1 623 geography NaN \n", + "1 accius-0 1 1076 biography NaN \n", + "2 achenbach-2 1 1357 biography NaN \n", + "3 acireale-0 1 1513 geography NaN \n", + "4 actée-0 1 1731 botany NaN \n", + "\n", + " content class_bert \n", + "0 ABRABESES. Village d’Espagne de la prov. de Za... Géographie \n", + "1 ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po... Belles-lettres - Poésie \n", + "2 ACHENBACH(Henri), administrateur prussien, né ... Histoire \n", + "3 ACIREALE. Yille de Sicile, de la province et d... Géographie \n", + "4 ACTÉE(Actœa L.). Genre de plantes de la famill... Histoire naturelle " + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_LGE.head()" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "df_LGE.to_csv(path + \"reports/classification_LGE.tsv\", sep=\"\\t\")" + ] } ], "metadata": {
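Note: the cells touched by this patch together implement the full inference flow — load the checkpoint, run `predict`, decode the integer predictions with the pickled `LabelEncoder`, and export the labelled dataframe as TSV. A minimal consolidated sketch of that flow follows, assuming `model`, `model_path`, `device`, `predict`, `data_loader`, `path`, and `df_LGE` are defined as in the earlier notebook cells; the `isinstance` branch reflects the `TypeError` removed above, which shows the checkpoint can be a fully pickled `BertForSequenceClassification` rather than a bare state dict.

```python
import pickle

import torch

# The checkpoint at model_path may be either a bare state_dict or a fully
# pickled BertForSequenceClassification (the TypeError removed in this patch
# shows the latter), so handle both cases before predicting.
checkpoint = torch.load(model_path, map_location=device)
if isinstance(checkpoint, torch.nn.Module):
    model = checkpoint                      # whole model was pickled
else:
    model.load_state_dict(checkpoint)       # plain state_dict; `model` is
                                            # assumed already instantiated
model.to(device)
model.eval()

# Run inference, then map class indices back to human-readable labels via
# the LabelEncoder pickled at training time.
pred = predict(model, data_loader, device)
with open(path + "models/label_encoder.pkl", "rb") as f:
    encoder = pickle.load(f)
df_LGE["class_bert"] = encoder.inverse_transform(pred)

# Persist the labelled articles for the reports directory.
df_LGE.to_csv(path + "reports/classification_LGE.tsv", sep="\t")
```

The `AttributeError` whose traceback this patch also removes (`'BertEncoder' object has no attribute 'gradient_checkpointing'`) is typically a symptom of unpickling a whole model under a different `transformers` version than the one it was saved with; loading only the weights into a freshly instantiated model, as sketched above, avoids that mismatch.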