diff --git a/notebooks/Predict_XAI.ipynb b/notebooks/Predict_XAI.ipynb
index 517c7a3019ca40fd5cd0c7ec2d2f35338a460c37..694cd1bc35a809837a3780254ab77bd325dff079 100644
--- a/notebooks/Predict_XAI.ipynb
+++ b/notebooks/Predict_XAI.ipynb
@@ -755,7 +755,7 @@
    "source": [
     "## 4. Load model and predict\n",
     "\n",
-    "### 4.1 BERT"
+    "### 4.1 Load BERT model"
    ]
   },
   {
@@ -770,6 +770,17 @@
     "model_path = path + \"models/model_\" + model_name + \"_s10000.pt\""
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoder_filename = \"models/label_encoder.pkl\"\n",
+    "with open(path + encoder_filename, 'rb') as file:\n",
+    "    encoder = pickle.load(file)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 16,
@@ -784,10 +795,25 @@
     }
    ],
    "source": [
-    "print('Loading Bert Tokenizer...')\n",
     "tokenizer = BertTokenizer.from_pretrained(model_name)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = BertForSequenceClassification.from_pretrained(model_path).to(gpu_name) #.to(\"cuda\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 4.2 Prepare datasets"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 17,
@@ -802,23 +828,35 @@
     }
    ],
    "source": [
-    "data_loader = generate_dataloader(tokenizer, data_LGE)"
+    "# LGE\n",
+    "data_loader_LGE = generate_dataloader(tokenizer, df_LGE.content.values)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = BertForSequenceClassification.from_pretrained(model_path).to(gpu_name) #.to(\"cuda\")"
+    "# LGE parallel\n",
+    "data_loader_LGE_par = generate_dataloader(tokenizer, df_LGE_par.content.values)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# EDdA\n",
+    "data_loader_EDdA = generate_dataloader(tokenizer, df_EDdA.content.values)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### 4.2 Predict"
+    "### 4.3 Predict"
    ]
   },
   {
@@ -833,36 +871,28 @@
    },
    "outputs": [],
    "source": [
-    "pred = predict(model, data_loader, device)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "encoder_filename = \"models/label_encoder.pkl\"\n",
-    "with open(path + encoder_filename, 'rb') as file:\n",
-    "    encoder = pickle.load(file)"
+    "pred_LGE = predict(model, data_loader_LGE, device)\n",
+    "df_LGE['class_pred'] = list(encoder.inverse_transform(pred_LGE))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "p2 = list(encoder.inverse_transform(pred))"
+    "pred_LGE_par = predict(model, data_loader_LGE_par, device)\n",
+    "df_LGE_par['class_pred'] = list(encoder.inverse_transform(pred_LGE_par))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
-    "df_LGE['domain'] = p2"
+    "pred_EDdA = predict(model, data_loader_EDdA, device)\n",
+    "df_EDdA['class_pred'] = list(encoder.inverse_transform(pred_EDdA))"
    ]
   },
   {
@@ -1569,7 +1599,7 @@
     }
    ],
    "source": [
-    "df_LGE.head(50)"
+    "df_LGE.head()"
    ]
   },
   {
@@ -1766,7 +1796,7 @@
     }
    ],
    "source": [
-    "content = \"Instrument de musique\" #df_LGE.content[2][:512]\n",
+    "content = \"Instrument de musique\" #df_LGE.content[2]\n",
     "word_attributions = cls_explainer(content if len(content) < 512 else content[:512])\n",
     "word_attributions"
    ]
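
For reference, a minimal sketch of the prediction flow these cells now implement, not part of the patch. It assumes the notebook's own helpers `generate_dataloader(tokenizer, texts)` and `predict(model, data_loader, device)`, plus the `path`, `model_name`, `model_path`, `gpu_name`, `device` and `df_*` variables defined in earlier cells.

```python
# Sketch only: generate_dataloader() and predict() are the notebook's own helpers;
# path, model_name, model_path, gpu_name, device, df_LGE, df_LGE_par and df_EDdA
# are assumed to be defined in earlier cells.
import pickle

from transformers import BertForSequenceClassification, BertTokenizer

# Label encoder fitted at training time, used to map class ids back to labels.
with open(path + "models/label_encoder.pkl", "rb") as file:
    encoder = pickle.load(file)

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_path).to(gpu_name)

# One dataloader and one prediction pass per corpus, decoded to class names.
for df in (df_LGE, df_LGE_par, df_EDdA):
    data_loader = generate_dataloader(tokenizer, df.content.values)
    pred = predict(model, data_loader, device)
    df["class_pred"] = list(encoder.inverse_transform(pred))
```

Looping over the three corpora mirrors the per-dataset cells above while keeping a single model and encoder in memory.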