From 05ac721c696fb1e7ba8697fb5faef3a19505da92 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Mon, 20 Mar 2023 20:16:20 +0100
Subject: [PATCH] Update Predict.ipynb

---
 notebooks/Predict.ipynb | 83 +++++++++++++++++++++++------------------
 1 file changed, 47 insertions(+), 36 deletions(-)

diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index 8fa3a5c..91a7f75 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -590,7 +590,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": 23,
       "metadata": {
         "id": "CN8EZst-r0zZ"
       },
@@ -600,7 +600,7 @@
         "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
         "\n",
         "#model = BertForSequenceClassification.from_pretrained(model_path).to(\"cuda\")\n",
-        "model = BertForSequenceClassification.from_pretrained(model_path).to(\"cpu\")"
+        "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\")"
       ]
     },
     {
@@ -616,7 +616,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": 33,
       "metadata": {},
       "outputs": [
         {
@@ -626,49 +626,60 @@
             "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated,  if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
             "  warnings.warn(\n"
           ]
-        },
-        {
-          "data": {
-            "text/plain": [
-              "[[{'label': 'LABEL_0', 'score': 9.58614400587976e-05},\n",
-              "  {'label': 'LABEL_1', 'score': 0.00010365043272031471},\n",
-              "  {'label': 'LABEL_2', 'score': 6.21283397777006e-05},\n",
-              "  {'label': 'LABEL_3', 'score': 9.175329614663497e-05},\n",
-              "  {'label': 'LABEL_4', 'score': 9.065424819709733e-05},\n",
-              "  {'label': 'LABEL_5', 'score': 0.00010455227311467752},\n",
-              "  {'label': 'LABEL_6', 'score': 0.9985577464103699},\n",
-              "  {'label': 'LABEL_7', 'score': 0.00013558757200371474},\n",
-              "  {'label': 'LABEL_8', 'score': 0.0001018877956084907},\n",
-              "  {'label': 'LABEL_9', 'score': 0.0001431443088222295},\n",
-              "  {'label': 'LABEL_10', 'score': 0.00010823880438692868},\n",
-              "  {'label': 'LABEL_11', 'score': 3.7985137169016525e-05},\n",
-              "  {'label': 'LABEL_12', 'score': 6.803833093727008e-05},\n",
-              "  {'label': 'LABEL_13', 'score': 4.024818554171361e-05},\n",
-              "  {'label': 'LABEL_14', 'score': 0.0001047810583258979},\n",
-              "  {'label': 'LABEL_15', 'score': 8.337549661519006e-05},\n",
-              "  {'label': 'LABEL_16', 'score': 7.031656423350796e-05}]]"
-            ]
-          },
-          "execution_count": 13,
-          "metadata": {},
-          "output_type": "execute_result"
         }
       ],
       "source": [
         "## TEST\n",
+        "# https://huggingface.co/docs/transformers/main_classes/pipelines\n",
+        "from transformers import TextClassificationPipeline\n",
         "\n",
+        "def data():\n",
+        "    for i in range(1000):\n",
+        "        yield f\"Lyon, petite ville de France. {i}\"\n",
         "\n",
-        "from transformers import TextClassificationPipeline\n",
         "\n",
-        "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n",
-        "# outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759},  {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n",
-        "prob = pipe(\"Lyon, petite ville de France, dans la région Rhone-Alpes.\")\n",
-        "prob"
+        "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True, device=device)\n"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": 35,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "[{'label': 'LABEL_0', 'score': 9.43058475968428e-05}, {'label': 'LABEL_1', 'score': 0.00013377856521401554}, {'label': 'LABEL_2', 'score': 6.315444625215605e-05}, {'label': 'LABEL_3', 'score': 9.087997023016214e-05}, {'label': 'LABEL_4', 'score': 0.00012772278569173068}, {'label': 'LABEL_5', 'score': 0.00012729596346616745}, {'label': 'LABEL_6', 'score': 0.9983708262443542}, {'label': 'LABEL_7', 'score': 0.00015073739632498473}, {'label': 'LABEL_8', 'score': 0.00013310853682924062}, {'label': 'LABEL_9', 'score': 0.0001363410265184939}, {'label': 'LABEL_10', 'score': 0.00011535766680026427}, {'label': 'LABEL_11', 'score': 4.8044770665001124e-05}, {'label': 'LABEL_12', 'score': 7.562591781606898e-05}, {'label': 'LABEL_13', 'score': 4.6062668843660504e-05}, {'label': 'LABEL_14', 'score': 0.00012537441216409206}, {'label': 'LABEL_15', 'score': 9.473998215980828e-05}, {'label': 'LABEL_16', 'score': 6.669617869192734e-05}]\n",
+            "6\n",
+            "[{'label': 'LABEL_0', 'score': 9.85840815701522e-05}, {'label': 'LABEL_1', 'score': 0.0001410262193530798}, {'label': 'LABEL_2', 'score': 6.340965774143115e-05}, {'label': 'LABEL_3', 'score': 9.572453564032912e-05}, {'label': 'LABEL_4', 'score': 0.00011747579992515966}, {'label': 'LABEL_5', 'score': 0.00012954592239111662}, {'label': 'LABEL_6', 'score': 0.9982858300209045}, {'label': 'LABEL_7', 'score': 0.0001560843811603263}, {'label': 'LABEL_8', 'score': 0.00015996283036656678}, {'label': 'LABEL_9', 'score': 0.0001614005013834685}, {'label': 'LABEL_10', 'score': 0.00010834677232196555}, {'label': 'LABEL_11', 'score': 4.9881378799909726e-05}, {'label': 'LABEL_12', 'score': 7.358138827839866e-05}, {'label': 'LABEL_13', 'score': 5.4664047638652846e-05}, {'label': 'LABEL_14', 'score': 0.00013466033851727843}, {'label': 'LABEL_15', 'score': 9.780169057194144e-05}, {'label': 'LABEL_16', 'score': 7.196604565251619e-05}]\n",
+            "6\n",
+            "[{'label': 'LABEL_0', 'score': 9.556901204632595e-05}, {'label': 'LABEL_1', 'score': 0.0001365469943266362}, {'label': 'LABEL_2', 'score': 6.268925790209323e-05}, {'label': 'LABEL_3', 'score': 9.737971413414925e-05}, {'label': 'LABEL_4', 'score': 0.00012014496314805001}, {'label': 'LABEL_5', 'score': 0.00012252115993760526}, {'label': 'LABEL_6', 'score': 0.9983487129211426}, {'label': 'LABEL_7', 'score': 0.0001454231096431613}, {'label': 'LABEL_8', 'score': 0.00014558130351360887}, {'label': 'LABEL_9', 'score': 0.00014958814426790923}, {'label': 'LABEL_10', 'score': 0.00011634181282715872}, {'label': 'LABEL_11', 'score': 4.5097345719113946e-05}, {'label': 'LABEL_12', 'score': 8.068335591815412e-05}, {'label': 'LABEL_13', 'score': 4.724525933852419e-05}, {'label': 'LABEL_14', 'score': 0.00012563375639729202}, {'label': 'LABEL_15', 'score': 9.24634441616945e-05}, {'label': 'LABEL_16', 'score': 6.83424441376701e-05}]\n",
+            "6\n",
+            "[{'label': 'LABEL_0', 'score': 9.575629519531503e-05}, {'label': 'LABEL_1', 'score': 0.00013479188783094287}, {'label': 'LABEL_2', 'score': 6.24070453341119e-05}, {'label': 'LABEL_3', 'score': 9.491511445958167e-05}, {'label': 'LABEL_4', 'score': 0.00011898632510565221}, {'label': 'LABEL_5', 'score': 0.00012223367230035365}, {'label': 'LABEL_6', 'score': 0.9983828067779541}, {'label': 'LABEL_7', 'score': 0.00014901417307555676}, {'label': 'LABEL_8', 'score': 0.0001293729292228818}, {'label': 'LABEL_9', 'score': 0.00014636504056397825}, {'label': 'LABEL_10', 'score': 0.00011709715909091756}, {'label': 'LABEL_11', 'score': 4.3970183469355106e-05}, {'label': 'LABEL_12', 'score': 7.832375558791682e-05}, {'label': 'LABEL_13', 'score': 4.6482757170451805e-05}, {'label': 'LABEL_14', 'score': 0.00011872482718899846}, {'label': 'LABEL_15', 'score': 9.005393803818151e-05}, {'label': 'LABEL_16', 'score': 6.87053834553808e-05}]\n",
+            "6\n",
+            "[{'label': 'LABEL_0', 'score': 9.33124974835664e-05}, {'label': 'LABEL_1', 'score': 0.00012642868387047201}, {'label': 'LABEL_2', 'score': 6.495929847005755e-05}, {'label': 'LABEL_3', 'score': 9.773051715455949e-05}, {'label': 'LABEL_4', 'score': 0.00011607634951360524}, {'label': 'LABEL_5', 'score': 0.00012188677646918222}, {'label': 'LABEL_6', 'score': 0.9983865022659302}, {'label': 'LABEL_7', 'score': 0.0001447165122954175}, {'label': 'LABEL_8', 'score': 0.00012925465125590563}, {'label': 'LABEL_9', 'score': 0.0001489764981670305}, {'label': 'LABEL_10', 'score': 0.0001232580398209393}, {'label': 'LABEL_11', 'score': 4.4239117414690554e-05}, {'label': 'LABEL_12', 'score': 7.944685057736933e-05}, {'label': 'LABEL_13', 'score': 4.5822369429515675e-05}, {'label': 'LABEL_14', 'score': 0.00011649943189695477}, {'label': 'LABEL_15', 'score': 9.088807564694434e-05}, {'label': 'LABEL_16', 'score': 6.998340541031212e-05}]\n",
+            "6\n",
+            "[{'label': 'LABEL_0', 'score': 9.538340236758813e-05}, {'label': 'LABEL_1', 'score': 0.00013363973994273692}, {'label': 'LABEL_2', 'score': 6.720751116517931e-05}, {'label': 'LABEL_3', 'score': 0.00010068194387713447}, {'label': 'LABEL_4', 'score': 0.00011288334644632414}, {'label': 'LABEL_5', 'score': 0.00012565024371724576}, {'label': 'LABEL_6', 'score': 0.9983444213867188}, {'label': 'LABEL_7', 'score': 0.00015267464914359152}, {'label': 'LABEL_8', 'score': 0.00014014744374435395}, {'label': 'LABEL_9', 'score': 0.00014672863471787423}, {'label': 'LABEL_10', 'score': 0.0001220486665260978}, {'label': 'LABEL_11', 'score': 4.699776036432013e-05}, {'label': 'LABEL_12', 'score': 7.61943738325499e-05}, {'label': 'LABEL_13', 'score': 4.92853214382194e-05}, {'label': 'LABEL_14', 'score': 0.00012135148426750675}, {'label': 'LABEL_15', 'score': 9.276873606722802e-05}, {'label': 'LABEL_16', 'score': 7.18621740816161e-05}]\n",
+            "6\n"
+          ]
+        }
+      ],
+      "source": [
+        "cpt = 0\n",
+        "for out in pipe(data()):\n",
+        "    print(out)\n",
+        "    # outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759},  {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n",
+        "    # proba de la class Géographie : 6\n",
+        "    print(out[6]['label'][6:]) ### TODO modifier ici\n",
+        "    cpt += 1\n",
+        "    if cpt == 6:\n",
+        "        break\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 28,
       "metadata": {},
       "outputs": [
         {
@@ -677,7 +688,7 @@
               "'0'"
             ]
           },
-          "execution_count": 19,
+          "execution_count": 28,
           "metadata": {},
           "output_type": "execute_result"
         }
-- 
GitLab