diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index 441b5242a5056e9207e8fa2effaa7bbae5ead680..8fa3a5c6c593be9e5ade5bb38e29f1fee7582420 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -186,7 +186,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": 3,
       "metadata": {
         "id": "M2awiee1r0zV"
       },
@@ -208,7 +208,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
+      "execution_count": 4,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -325,7 +325,7 @@
               "4  A(Numis.). Dans la numismatique grecque, la le...         67.0  "
             ]
           },
-          "execution_count": 5,
+          "execution_count": 4,
           "metadata": {},
           "output_type": "execute_result"
         }
@@ -359,7 +359,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": 5,
       "metadata": {
         "id": "0qDZ86qTr0zX"
       },
@@ -374,7 +374,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": 6,
       "metadata": {
         "id": "KEljGX0br0zX"
       },
@@ -475,12 +475,16 @@
         "            pred_labels.append(pred_labels_i)\n",
         "\n",
         "    pred_labels_ += [item for sublist in pred_labels for item in sublist]\n",
-        "    return pred_labels_"
+        "    return pred_labels_\n",
+        "\n",
+        "\n",
+        "\n",
+        "#https://discuss.huggingface.co/t/i-have-trained-my-classifier-now-how-do-i-do-predictions/3625/3\n"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": 7,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -586,7 +590,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": 10,
       "metadata": {
         "id": "CN8EZst-r0zZ"
       },
@@ -596,7 +600,7 @@
         "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
         "\n",
         "#model = BertForSequenceClassification.from_pretrained(model_path).to(\"cuda\")\n",
-        "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\")"
+        "model = BertForSequenceClassification.from_pretrained(model_path).to(\"cpu\")"
       ]
     },
     {
@@ -610,6 +614,100 @@
         "pred = predict(model, data_loader, device)"
       ]
     },
+    {
+      "cell_type": "code",
+      "execution_count": 13,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated,  if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
+            "  warnings.warn(\n"
+          ]
+        },
+        {
+          "data": {
+            "text/plain": [
+              "[[{'label': 'LABEL_0', 'score': 9.58614400587976e-05},\n",
+              "  {'label': 'LABEL_1', 'score': 0.00010365043272031471},\n",
+              "  {'label': 'LABEL_2', 'score': 6.21283397777006e-05},\n",
+              "  {'label': 'LABEL_3', 'score': 9.175329614663497e-05},\n",
+              "  {'label': 'LABEL_4', 'score': 9.065424819709733e-05},\n",
+              "  {'label': 'LABEL_5', 'score': 0.00010455227311467752},\n",
+              "  {'label': 'LABEL_6', 'score': 0.9985577464103699},\n",
+              "  {'label': 'LABEL_7', 'score': 0.00013558757200371474},\n",
+              "  {'label': 'LABEL_8', 'score': 0.0001018877956084907},\n",
+              "  {'label': 'LABEL_9', 'score': 0.0001431443088222295},\n",
+              "  {'label': 'LABEL_10', 'score': 0.00010823880438692868},\n",
+              "  {'label': 'LABEL_11', 'score': 3.7985137169016525e-05},\n",
+              "  {'label': 'LABEL_12', 'score': 6.803833093727008e-05},\n",
+              "  {'label': 'LABEL_13', 'score': 4.024818554171361e-05},\n",
+              "  {'label': 'LABEL_14', 'score': 0.0001047810583258979},\n",
+              "  {'label': 'LABEL_15', 'score': 8.337549661519006e-05},\n",
+              "  {'label': 'LABEL_16', 'score': 7.031656423350796e-05}]]"
+            ]
+          },
+          "execution_count": 13,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "## TEST\n",
+        "\n",
+        "\n",
+        "from transformers import TextClassificationPipeline\n",
+        "\n",
+        "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n",
+        "# outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759},  {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n",
+        "prob = pipe(\"Lyon, petite ville de France, dans la région Rhone-Alpes.\")\n",
+        "prob"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 19,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "'0'"
+            ]
+          },
+          "execution_count": 19,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "prob[0][0]['label'][6:]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 22,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "array(['Agriculture'], dtype=object)"
+            ]
+          },
+          "execution_count": 22,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "## TEST\n",
+        "\n",
+        "encoder.inverse_transform([int(prob[0][0]['label'][6:])])"
+      ]
+    },
     {
       "cell_type": "code",
       "execution_count": 13,
@@ -1638,7 +1736,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": 12,
       "metadata": {
         "id": "fo6k4li1r0za"
       },
@@ -2200,7 +2298,7 @@
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
-      "version": "3.9.15"
+      "version": "3.9.13"
     },
     "vscode": {
       "interpreter": {