diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb index 441b5242a5056e9207e8fa2effaa7bbae5ead680..8fa3a5c6c593be9e5ade5bb38e29f1fee7582420 100644 --- a/notebooks/Predict.ipynb +++ b/notebooks/Predict.ipynb @@ -186,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "id": "M2awiee1r0zV" }, @@ -208,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -325,7 +325,7 @@ "4 A(Numis.). Dans la numismatique grecque, la le... 67.0 " ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -359,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": { "id": "0qDZ86qTr0zX" }, @@ -374,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "id": "KEljGX0br0zX" }, @@ -475,12 +475,16 @@ " pred_labels.append(pred_labels_i)\n", "\n", " pred_labels_ += [item for sublist in pred_labels for item in sublist]\n", - " return pred_labels_" + " return pred_labels_\n", + "\n", + "\n", + "\n", + "#https://discuss.huggingface.co/t/i-have-trained-my-classifier-now-how-do-i-do-predictions/3625/3\n" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -586,7 +590,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "id": "CN8EZst-r0zZ" }, @@ -596,7 +600,7 @@ "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n", "\n", "#model = BertForSequenceClassification.from_pretrained(model_path).to(\"cuda\")\n", - "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\")" + "model = BertForSequenceClassification.from_pretrained(model_path).to(\"cpu\")" ] }, { @@ -610,6 +614,100 @@ "pred = predict(model, data_loader, device)" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/pipelines/text_classification.py:89: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "[[{'label': 'LABEL_0', 'score': 9.58614400587976e-05},\n", + " {'label': 'LABEL_1', 'score': 0.00010365043272031471},\n", + " {'label': 'LABEL_2', 'score': 6.21283397777006e-05},\n", + " {'label': 'LABEL_3', 'score': 9.175329614663497e-05},\n", + " {'label': 'LABEL_4', 'score': 9.065424819709733e-05},\n", + " {'label': 'LABEL_5', 'score': 0.00010455227311467752},\n", + " {'label': 'LABEL_6', 'score': 0.9985577464103699},\n", + " {'label': 'LABEL_7', 'score': 0.00013558757200371474},\n", + " {'label': 'LABEL_8', 'score': 0.0001018877956084907},\n", + " {'label': 'LABEL_9', 'score': 0.0001431443088222295},\n", + " {'label': 'LABEL_10', 'score': 0.00010823880438692868},\n", + " {'label': 'LABEL_11', 'score': 3.7985137169016525e-05},\n", + " {'label': 'LABEL_12', 'score': 6.803833093727008e-05},\n", + " {'label': 'LABEL_13', 'score': 4.024818554171361e-05},\n", + " {'label': 'LABEL_14', 'score': 0.0001047810583258979},\n", + " {'label': 'LABEL_15', 'score': 8.337549661519006e-05},\n", + " {'label': 'LABEL_16', 'score': 7.031656423350796e-05}]]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## TEST\n", + "\n", + "\n", + "from transformers import TextClassificationPipeline\n", + "\n", + "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n", + "# outputs a list of dicts like [[{'label': 'NEGATIVE', 'score': 0.0001223755971295759}, {'label': 'POSITIVE', 'score': 0.9998776316642761}]]\n", + "prob = pipe(\"Lyon, petite ville de France, dans la région Rhone-Alpes.\")\n", + "prob" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'0'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prob[0][0]['label'][6:]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['Agriculture'], dtype=object)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## TEST\n", + "\n", + "encoder.inverse_transform([int(prob[0][0]['label'][6:])])" + ] + }, { "cell_type": "code", "execution_count": 13, @@ -1638,7 +1736,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": { "id": "fo6k4li1r0za" }, @@ -2200,7 +2298,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.9.13" }, "vscode": { "interpreter": {