diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index f479b3e92a836387e0e59adc5e671e6e587344be..76e7e0ac071aed40627043e8e26de73460c8d9b4 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -139,7 +139,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 23,
+      "execution_count": 2,
       "metadata": {
         "id": "SkErnwgMMbRj"
       },
@@ -330,7 +330,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": 14,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -341,7 +341,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": 15,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -445,7 +445,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": 16,
       "metadata": {},
       "outputs": [
         {
@@ -467,7 +467,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": 17,
       "metadata": {},
       "outputs": [
         {
@@ -494,22 +494,9 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 26,
+      "execution_count": 18,
       "metadata": {},
-      "outputs": [
-        {
-          "ename": "TypeError",
-          "evalue": "Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>.",
-          "output_type": "error",
-          "traceback": [
-            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-            "Cell \u001b[0;32mIn [26], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m#model = torch.load(model_path, map_location=torch.device('mps'))\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model\u001b[39m.\u001b[39;49mload_state_dict(torch\u001b[39m.\u001b[39;49mload(model_path, map_location\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mdevice(\u001b[39m'\u001b[39;49m\u001b[39mmps\u001b[39;49m\u001b[39m'\u001b[39;49m)))\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1620\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m   1597\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Copies parameters and buffers from :attr:`state_dict` into\u001b[39;00m\n\u001b[1;32m   1598\u001b[0m \u001b[39mthis module and its descendants. If :attr:`strict` is ``True``, then\u001b[39;00m\n\u001b[1;32m   1599\u001b[0m \u001b[39mthe keys of :attr:`state_dict` must exactly match the keys returned\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1617\u001b[0m \u001b[39m    ``RuntimeError``.\u001b[39;00m\n\u001b[1;32m   1618\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1619\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(state_dict, Mapping):\n\u001b[0;32m-> 1620\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mExpected state_dict to be dict-like, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mtype\u001b[39m(state_dict)))\n\u001b[1;32m   1622\u001b[0m missing_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n\u001b[1;32m   1623\u001b[0m unexpected_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n",
-            "\u001b[0;31mTypeError\u001b[0m: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>."
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "#model = torch.load(model_path, map_location=torch.device('mps'))\n",
         "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
@@ -519,7 +506,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": 19,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -527,65 +514,493 @@
         "id": "_fzgS5USJeAF",
         "outputId": "be4a5506-76ed-4eef-bb3c-fe2bb77c6e4d"
       },
+      "outputs": [],
+      "source": [
+        "pred = predict(model, data_loader, device)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 20,
+      "metadata": {},
       "outputs": [
         {
-          "ename": "AttributeError",
-          "evalue": "'BertEncoder' object has no attribute 'gradient_checkpointing'",
-          "output_type": "error",
-          "traceback": [
-            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-            "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
-            "Cell \u001b[0;32mIn [14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pred \u001b[39m=\u001b[39m predict(model, data_loader, device)\n",
-            "Cell \u001b[0;32mIn [8], line 68\u001b[0m, in \u001b[0;36mpredict\u001b[0;34m(model, dataloader, device)\u001b[0m\n\u001b[1;32m     64\u001b[0m \u001b[39m# Telling the model not to compute or store gradients, saving memory and\u001b[39;00m\n\u001b[1;32m     65\u001b[0m \u001b[39m# speeding up prediction\u001b[39;00m\n\u001b[1;32m     66\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m     67\u001b[0m     \u001b[39m# Forward pass, calculate logit predictions\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m     outputs \u001b[39m=\u001b[39m model(b_input_ids, token_type_ids\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m     69\u001b[0m                     attention_mask\u001b[39m=\u001b[39;49mb_input_mask)\n\u001b[1;32m     71\u001b[0m logits \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m     72\u001b[0m \u001b[39m#print(logits)\u001b[39;00m\n\u001b[1;32m     73\u001b[0m \n\u001b[1;32m     74\u001b[0m \u001b[39m# Move logits and labels to CPU ???\u001b[39;00m\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1552\u001b[0m, in \u001b[0;36mBertForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1544\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1545\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m   1546\u001b[0m \u001b[39m    Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m   1547\u001b[0m \u001b[39m    config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m   1548\u001b[0m \u001b[39m    `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m   1549\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1550\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1552\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbert(\n\u001b[1;32m   1553\u001b[0m     input_ids,\n\u001b[1;32m   1554\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m   1555\u001b[0m     token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m   1556\u001b[0m     position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m   1557\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m   1558\u001b[0m     
inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m   1559\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m   1560\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m   1561\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m   1562\u001b[0m )\n\u001b[1;32m   1564\u001b[0m pooled_output \u001b[39m=\u001b[39m outputs[\u001b[39m1\u001b[39m]\n\u001b[1;32m   1566\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout(pooled_output)\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1014\u001b[0m, in \u001b[0;36mBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1005\u001b[0m head_mask \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_head_mask(head_mask, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m   1007\u001b[0m embedding_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(\n\u001b[1;32m   1008\u001b[0m     input_ids\u001b[39m=\u001b[39minput_ids,\n\u001b[1;32m   1009\u001b[0m     position_ids\u001b[39m=\u001b[39mposition_ids,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1012\u001b[0m     past_key_values_length\u001b[39m=\u001b[39mpast_key_values_length,\n\u001b[1;32m   1013\u001b[0m )\n\u001b[0;32m-> 1014\u001b[0m encoder_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mencoder(\n\u001b[1;32m   1015\u001b[0m     embedding_output,\n\u001b[1;32m   1016\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mextended_attention_mask,\n\u001b[1;32m   1017\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m   1018\u001b[0m     encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m   1019\u001b[0m     encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_extended_attention_mask,\n\u001b[1;32m   1020\u001b[0m     past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m   1021\u001b[0m     use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m   1022\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m   1023\u001b[0m     
output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m   1024\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m   1025\u001b[0m )\n\u001b[1;32m   1026\u001b[0m sequence_output \u001b[39m=\u001b[39m encoder_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m   1027\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler(sequence_output) \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:580\u001b[0m, in \u001b[0;36mBertEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    577\u001b[0m layer_head_mask \u001b[39m=\u001b[39m head_mask[i] \u001b[39mif\u001b[39;00m head_mask \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m    578\u001b[0m past_key_value \u001b[39m=\u001b[39m past_key_values[i] \u001b[39mif\u001b[39;00m past_key_values \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> 580\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgradient_checkpointing \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining:\n\u001b[1;32m    582\u001b[0m     \u001b[39mif\u001b[39;00m use_cache:\n\u001b[1;32m    583\u001b[0m         logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m    584\u001b[0m             \u001b[39m\"\u001b[39m\u001b[39m`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    585\u001b[0m         )\n",
-            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1265\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1263\u001b[0m     \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m   1264\u001b[0m         \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1265\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m   1266\u001b[0m     \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
-            "\u001b[0;31mAttributeError\u001b[0m: 'BertEncoder' object has no attribute 'gradient_checkpointing'"
-          ]
+          "data": {
+            "text/plain": [
+              "[15,\n",
+              " 6,\n",
+              " 16,\n",
+              " 15,\n",
+              " 17,\n",
+              " 10,\n",
+              " 17,\n",
+              " 16,\n",
+              " 19,\n",
+              " 35,\n",
+              " 15,\n",
+              " 26,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 2,\n",
+              " 2,\n",
+              " 17,\n",
+              " 6,\n",
+              " 32,\n",
+              " 17,\n",
+              " 30,\n",
+              " 16,\n",
+              " 32,\n",
+              " 15,\n",
+              " 35,\n",
+              " 15,\n",
+              " 23,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 17,\n",
+              " 15,\n",
+              " 16,\n",
+              " 3,\n",
+              " 17,\n",
+              " 17,\n",
+              " 16,\n",
+              " 4,\n",
+              " 15,\n",
+              " 17,\n",
+              " 19,\n",
+              " 16,\n",
+              " 35,\n",
+              " 3,\n",
+              " 17,\n",
+              " 5,\n",
+              " 15,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 16,\n",
+              " 6,\n",
+              " 16,\n",
+              " 5,\n",
+              " 16,\n",
+              " 15,\n",
+              " 28,\n",
+              " 16,\n",
+              " 17,\n",
+              " 10,\n",
+              " 15,\n",
+              " 15,\n",
+              " 32,\n",
+              " 15,\n",
+              " 17,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 12,\n",
+              " 15,\n",
+              " 18,\n",
+              " 15,\n",
+              " 35,\n",
+              " 26,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 5,\n",
+              " 15,\n",
+              " 15,\n",
+              " 5,\n",
+              " 17,\n",
+              " 15,\n",
+              " 17,\n",
+              " 35,\n",
+              " 15,\n",
+              " 16,\n",
+              " 16,\n",
+              " 17,\n",
+              " 2,\n",
+              " 17,\n",
+              " 15,\n",
+              " 16,\n",
+              " 23,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 6,\n",
+              " 15,\n",
+              " 35,\n",
+              " 15,\n",
+              " 32,\n",
+              " 16,\n",
+              " 6,\n",
+              " 16,\n",
+              " 23,\n",
+              " 36,\n",
+              " 5,\n",
+              " 35,\n",
+              " 3,\n",
+              " 3,\n",
+              " 3,\n",
+              " 16,\n",
+              " 17,\n",
+              " 2,\n",
+              " 15,\n",
+              " 5,\n",
+              " 17,\n",
+              " 16,\n",
+              " 15,\n",
+              " 17,\n",
+              " 6,\n",
+              " 15,\n",
+              " 16,\n",
+              " 10,\n",
+              " 16,\n",
+              " 15,\n",
+              " 35,\n",
+              " 17,\n",
+              " 15,\n",
+              " 15,\n",
+              " 6,\n",
+              " 28,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 5,\n",
+              " 15,\n",
+              " 21,\n",
+              " 5,\n",
+              " 1,\n",
+              " 7,\n",
+              " 16,\n",
+              " 15,\n",
+              " 17,\n",
+              " 23,\n",
+              " 15,\n",
+              " 5,\n",
+              " 0,\n",
+              " 10,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 3,\n",
+              " 3,\n",
+              " 17,\n",
+              " 36,\n",
+              " 16,\n",
+              " 15,\n",
+              " 12,\n",
+              " 6,\n",
+              " 15,\n",
+              " 4,\n",
+              " 16,\n",
+              " 16,\n",
+              " 26,\n",
+              " 15,\n",
+              " 15,\n",
+              " 32,\n",
+              " 15,\n",
+              " 10,\n",
+              " 15,\n",
+              " 5,\n",
+              " 26,\n",
+              " 5,\n",
+              " 15,\n",
+              " 15,\n",
+              " 26,\n",
+              " 15,\n",
+              " 35,\n",
+              " 15,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 6,\n",
+              " 16,\n",
+              " 12,\n",
+              " 16,\n",
+              " 28,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 6,\n",
+              " 10,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 5,\n",
+              " 16,\n",
+              " 16,\n",
+              " 17,\n",
+              " 15,\n",
+              " 16,\n",
+              " 35,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 6,\n",
+              " 29,\n",
+              " 16,\n",
+              " 15,\n",
+              " 5,\n",
+              " 5,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 31,\n",
+              " 16,\n",
+              " 15,\n",
+              " 16,\n",
+              " 15,\n",
+              " 6,\n",
+              " 16,\n",
+              " 3,\n",
+              " 15,\n",
+              " 2,\n",
+              " 15,\n",
+              " 15,\n",
+              " 28,\n",
+              " 17,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 10,\n",
+              " 15,\n",
+              " 5,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 17,\n",
+              " 15,\n",
+              " 5,\n",
+              " 15,\n",
+              " 3,\n",
+              " 15,\n",
+              " 2,\n",
+              " 15,\n",
+              " 15,\n",
+              " 6,\n",
+              " 15,\n",
+              " 28,\n",
+              " 15,\n",
+              " 6,\n",
+              " 15,\n",
+              " 32,\n",
+              " 16,\n",
+              " 15,\n",
+              " 2,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 17,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 16,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 35,\n",
+              " 15,\n",
+              " 15,\n",
+              " 35,\n",
+              " 16,\n",
+              " 28,\n",
+              " 15,\n",
+              " 15,\n",
+              " 15,\n",
+              " 5,\n",
+              " 15,\n",
+              " 15,\n",
+              " 19,\n",
+              " 15]"
+            ]
+          },
+          "execution_count": 20,
+          "metadata": {},
+          "output_type": "execute_result"
         }
       ],
       "source": [
-        "pred = predict(model, data_loader, device)"
+        "pred"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 22,
       "metadata": {},
       "outputs": [],
-      "source": []
+      "source": [
+        "import pickle\n",
+        "encoder_filename = \"models/label_encoder.pkl\"\n",
+        "with open(path + encoder_filename, 'rb') as file:\n",
+        "    encoder = pickle.load(file)"
+      ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 23,
       "metadata": {},
       "outputs": [],
-      "source": []
+      "source": [
+        "p2 = list(encoder.inverse_transform(pred))"
+      ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 24,
       "metadata": {},
       "outputs": [],
-      "source": []
+      "source": [
+        "df_LGE['class_bert'] = p2"
+      ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 25,
       "metadata": {},
-      "outputs": [],
-      "source": []
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<div>\n",
+              "<style scoped>\n",
+              "    .dataframe tbody tr th:only-of-type {\n",
+              "        vertical-align: middle;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe tbody tr th {\n",
+              "        vertical-align: top;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe thead th {\n",
+              "        text-align: right;\n",
+              "    }\n",
+              "</style>\n",
+              "<table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              "    <tr style=\"text-align: right;\">\n",
+              "      <th></th>\n",
+              "      <th>id</th>\n",
+              "      <th>tome</th>\n",
+              "      <th>rank</th>\n",
+              "      <th>domain</th>\n",
+              "      <th>remark</th>\n",
+              "      <th>content</th>\n",
+              "      <th>class_bert</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <th>0</th>\n",
+              "      <td>abrabeses-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>623</td>\n",
+              "      <td>geography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ABRABESES. Village d’Espagne de la prov. de Za...</td>\n",
+              "      <td>Géographie</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>1</th>\n",
+              "      <td>accius-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1076</td>\n",
+              "      <td>biography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...</td>\n",
+              "      <td>Belles-lettres - Poésie</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>2</th>\n",
+              "      <td>achenbach-2</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1357</td>\n",
+              "      <td>biography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACHENBACH(Henri), administrateur prussien, né ...</td>\n",
+              "      <td>Histoire</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>3</th>\n",
+              "      <td>acireale-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1513</td>\n",
+              "      <td>geography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACIREALE. Yille de Sicile, de la province et d...</td>\n",
+              "      <td>Géographie</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>4</th>\n",
+              "      <td>actée-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1731</td>\n",
+              "      <td>botany</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACTÉE(Actœa L.). Genre de plantes de la famill...</td>\n",
+              "      <td>Histoire naturelle</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table>\n",
+              "</div>"
+            ],
+            "text/plain": [
+              "            id  tome  rank     domain remark  \\\n",
+              "0  abrabeses-0     1   623  geography    NaN   \n",
+              "1     accius-0     1  1076  biography    NaN   \n",
+              "2  achenbach-2     1  1357  biography    NaN   \n",
+              "3   acireale-0     1  1513  geography    NaN   \n",
+              "4      actée-0     1  1731     botany    NaN   \n",
+              "\n",
+              "                                             content               class_bert  \n",
+              "0  ABRABESES. Village d’Espagne de la prov. de Za...               Géographie  \n",
+              "1  ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...  Belles-lettres - Poésie  \n",
+              "2  ACHENBACH(Henri), administrateur prussien, né ...                 Histoire  \n",
+              "3  ACIREALE. Yille de Sicile, de la province et d...               Géographie  \n",
+              "4  ACTÉE(Actœa L.). Genre de plantes de la famill...       Histoire naturelle  "
+            ]
+          },
+          "execution_count": 25,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "df_LGE.head()"
+      ]
     },
     {
       "cell_type": "code",
       "execution_count": null,
       "metadata": {},
       "outputs": [],
-      "source": []
+      "source": [
+        "df_LGE.to_csv(path + \"reports/classification_LGE.tsv\", sep=\"\\t\")"
+      ]
     }
   ],
   "metadata": {