diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index 8fa3a5c6c593be9e5ade5bb38e29f1fee7582420..91a7f75877346483b1e2155d44af5093e597477d 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -590,7 +590,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 10,
+ "execution_count": 23,
 "metadata": {
 "id": "CN8EZst-r0zZ"
 },
@@ -600,7 +600,7 @@
 "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
 "\n",
 "#model = BertForSequenceClassification.from_pretrained(model_path).to(\"cuda\")\n",
- "model = BertForSequenceClassification.from_pretrained(model_path).to(\"cpu\")"
+ "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\")"
 ]
 },
 {
@@ -616,7 +616,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 13,
+ "execution_count": 33,
 "metadata": {},
 "outputs": [
 {
@@ -626,49 +626,60 @@
 "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
 " warnings.warn(\n"
 ]
- },
- {
- "data": {
- "text/plain": [
- "[[{'label': 'LABEL_0', 'score': 9.58614400587976e-05},\n",
- " {'label': 'LABEL_1', 'score': 0.00010365043272031471},\n",
- " {'label': 'LABEL_2', 'score': 6.21283397777006e-05},\n",
- " {'label': 'LABEL_3', 'score': 9.175329614663497e-05},\n",
- " {'label': 'LABEL_4', 'score': 9.065424819709733e-05},\n",
- " {'label': 'LABEL_5', 'score': 0.00010455227311467752},\n",
- " {'label': 'LABEL_6', 'score': 0.9985577464103699},\n",
- " {'label': 'LABEL_7', 'score': 0.00013558757200371474},\n",
- " {'label': 'LABEL_8', 'score': 0.0001018877956084907},\n",
- " {'label': 'LABEL_9', 'score': 0.0001431443088222295},\n",
- " {'label': 'LABEL_10', 'score': 0.00010823880438692868},\n",
- " {'label': 'LABEL_11', 'score': 3.7985137169016525e-05},\n",
- " {'label': 'LABEL_12', 'score': 6.803833093727008e-05},\n",
- " {'label': 'LABEL_13', 'score': 4.024818554171361e-05},\n",
- " {'label': 'LABEL_14', 'score': 0.0001047810583258979},\n",
- " {'label': 'LABEL_15', 'score': 8.337549661519006e-05},\n",
- " {'label': 'LABEL_16', 'score': 7.031656423350796e-05}]]"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
 }
 ],
 "source": [
 "## TEST\n",
+ "# https://huggingface.co/docs/transformers/main_classes/pipelines\n",
+ "from transformers import TextClassificationPipeline\n",
 "\n",
+ "def data():\n",
+ "    for i in range(1000):\n",
+ "        yield f\"Lyon, petite ville de France. {i}\"\n",
 "\n",
- "from transformers import TextClassificationPipeline\n",
 "\n",
- "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n",
- "# outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759}, {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n",
- "prob = pipe(\"Lyon, petite ville de France, dans la région Rhone-Alpes.\")\n",
- "prob"
+ "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True, device=device)\n"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": 19,
+ "execution_count": 35,
 "metadata": {},
 "outputs": [
 {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'label': 'LABEL_0', 'score': 9.43058475968428e-05}, {'label': 'LABEL_1', 'score': 0.00013377856521401554}, {'label': 'LABEL_2', 'score': 6.315444625215605e-05}, {'label': 'LABEL_3', 'score': 9.087997023016214e-05}, {'label': 'LABEL_4', 'score': 0.00012772278569173068}, {'label': 'LABEL_5', 'score': 0.00012729596346616745}, {'label': 'LABEL_6', 'score': 0.9983708262443542}, {'label': 'LABEL_7', 'score': 0.00015073739632498473}, {'label': 'LABEL_8', 'score': 0.00013310853682924062}, {'label': 'LABEL_9', 'score': 0.0001363410265184939}, {'label': 'LABEL_10', 'score': 0.00011535766680026427}, {'label': 'LABEL_11', 'score': 4.8044770665001124e-05}, {'label': 'LABEL_12', 'score': 7.562591781606898e-05}, {'label': 'LABEL_13', 'score': 4.6062668843660504e-05}, {'label': 'LABEL_14', 'score': 0.00012537441216409206}, {'label': 'LABEL_15', 'score': 9.473998215980828e-05}, {'label': 'LABEL_16', 'score': 6.669617869192734e-05}]\n",
+ "6\n",
+ "[{'label': 'LABEL_0', 'score': 9.85840815701522e-05}, {'label': 'LABEL_1', 'score': 0.0001410262193530798}, {'label': 'LABEL_2', 'score': 6.340965774143115e-05}, {'label': 'LABEL_3', 'score': 9.572453564032912e-05}, {'label': 'LABEL_4', 'score': 0.00011747579992515966}, {'label': 'LABEL_5', 'score': 0.00012954592239111662}, {'label': 'LABEL_6', 'score': 0.9982858300209045}, {'label': 'LABEL_7', 'score': 0.0001560843811603263}, {'label': 'LABEL_8', 'score': 0.00015996283036656678}, {'label': 'LABEL_9', 'score': 0.0001614005013834685}, {'label': 'LABEL_10', 'score': 0.00010834677232196555}, {'label': 'LABEL_11', 'score': 4.9881378799909726e-05}, {'label': 'LABEL_12', 'score': 7.358138827839866e-05}, {'label': 'LABEL_13', 'score': 5.4664047638652846e-05}, {'label': 'LABEL_14', 'score': 0.00013466033851727843}, {'label': 'LABEL_15', 'score': 9.780169057194144e-05}, {'label': 'LABEL_16', 'score': 7.196604565251619e-05}]\n",
+ "6\n",
+ "[{'label': 'LABEL_0', 'score': 9.556901204632595e-05}, {'label': 'LABEL_1', 'score': 0.0001365469943266362}, {'label': 'LABEL_2', 'score': 6.268925790209323e-05}, {'label': 'LABEL_3', 'score': 9.737971413414925e-05}, {'label': 'LABEL_4', 'score': 0.00012014496314805001}, {'label': 'LABEL_5', 'score': 0.00012252115993760526}, {'label': 'LABEL_6', 'score': 0.9983487129211426}, {'label': 'LABEL_7', 'score': 0.0001454231096431613}, {'label': 'LABEL_8', 'score': 0.00014558130351360887}, {'label': 'LABEL_9', 'score': 0.00014958814426790923}, {'label': 'LABEL_10', 'score': 0.00011634181282715872}, {'label': 'LABEL_11', 'score': 4.5097345719113946e-05}, {'label': 'LABEL_12', 'score': 8.068335591815412e-05}, {'label': 'LABEL_13', 'score': 4.724525933852419e-05}, {'label': 'LABEL_14', 'score': 0.00012563375639729202}, {'label': 'LABEL_15', 'score': 9.24634441616945e-05}, {'label': 'LABEL_16', 'score': 6.83424441376701e-05}]\n",
+ "6\n",
+ "[{'label': 'LABEL_0', 'score': 9.575629519531503e-05}, {'label': 'LABEL_1', 'score': 0.00013479188783094287}, {'label': 'LABEL_2', 'score': 6.24070453341119e-05}, {'label': 'LABEL_3', 'score': 9.491511445958167e-05}, {'label': 'LABEL_4', 'score': 0.00011898632510565221}, {'label': 'LABEL_5', 'score': 0.00012223367230035365}, {'label': 'LABEL_6', 'score': 0.9983828067779541}, {'label': 'LABEL_7', 'score': 0.00014901417307555676}, {'label': 'LABEL_8', 'score': 0.0001293729292228818}, {'label': 'LABEL_9', 'score': 0.00014636504056397825}, {'label': 'LABEL_10', 'score': 0.00011709715909091756}, {'label': 'LABEL_11', 'score': 4.3970183469355106e-05}, {'label': 'LABEL_12', 'score': 7.832375558791682e-05}, {'label': 'LABEL_13', 'score': 4.6482757170451805e-05}, {'label': 'LABEL_14', 'score': 0.00011872482718899846}, {'label': 'LABEL_15', 'score': 9.005393803818151e-05}, {'label': 'LABEL_16', 'score': 6.87053834553808e-05}]\n",
+ "6\n",
+ "[{'label': 'LABEL_0', 'score': 9.33124974835664e-05}, {'label': 'LABEL_1', 'score': 0.00012642868387047201}, {'label': 'LABEL_2', 'score': 6.495929847005755e-05}, {'label': 'LABEL_3', 'score': 9.773051715455949e-05}, {'label': 'LABEL_4', 'score': 0.00011607634951360524}, {'label': 'LABEL_5', 'score': 0.00012188677646918222}, {'label': 'LABEL_6', 'score': 0.9983865022659302}, {'label': 'LABEL_7', 'score': 0.0001447165122954175}, {'label': 'LABEL_8', 'score': 0.00012925465125590563}, {'label': 'LABEL_9', 'score': 0.0001489764981670305}, {'label': 'LABEL_10', 'score': 0.0001232580398209393}, {'label': 'LABEL_11', 'score': 4.4239117414690554e-05}, {'label': 'LABEL_12', 'score': 7.944685057736933e-05}, {'label': 'LABEL_13', 'score': 4.5822369429515675e-05}, {'label': 'LABEL_14', 'score': 0.00011649943189695477}, {'label': 'LABEL_15', 'score': 9.088807564694434e-05}, {'label': 'LABEL_16', 'score': 6.998340541031212e-05}]\n",
+ "6\n",
+ "[{'label': 'LABEL_0', 'score': 9.538340236758813e-05}, {'label': 'LABEL_1', 'score': 0.00013363973994273692}, {'label': 'LABEL_2', 'score': 6.720751116517931e-05}, {'label': 'LABEL_3', 'score': 0.00010068194387713447}, {'label': 'LABEL_4', 'score': 0.00011288334644632414}, {'label': 'LABEL_5', 'score': 0.00012565024371724576}, {'label': 'LABEL_6', 'score': 0.9983444213867188}, {'label': 'LABEL_7', 'score': 0.00015267464914359152}, {'label': 'LABEL_8', 'score': 0.00014014744374435395}, {'label': 'LABEL_9', 'score': 0.00014672863471787423}, {'label': 'LABEL_10', 'score': 0.0001220486665260978}, {'label': 'LABEL_11', 'score': 4.699776036432013e-05}, {'label': 'LABEL_12', 'score': 7.61943738325499e-05}, {'label': 'LABEL_13', 'score': 4.92853214382194e-05}, {'label': 'LABEL_14', 'score': 0.00012135148426750675}, {'label': 'LABEL_15', 'score': 9.276873606722802e-05}, {'label': 'LABEL_16', 'score': 7.18621740816161e-05}]\n",
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "cpt = 0\n",
+ "for out in pipe(data()):\n",
+ "    print(out)\n",
+ "    # outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759}, {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n",
+ "    # proba de la class Géographie : 6\n",
+ "    print(out[6]['label'][6:]) ### TODO modifier ici\n",
+ "    cpt += 1\n",
+ "    if cpt == 6:\n",
+ "        break\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
@@ -677,7 +688,7 @@
 "'0'"
 ]
 },
- "execution_count": 19,
+ "execution_count": 28,
 "metadata": {},
 "output_type": "execute_result"
 },