From b8b339abc832b33c60cfb025981a44075ee3767e Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Mon, 28 Nov 2022 13:30:00 +0100
Subject: [PATCH] Update Predict_XAI.ipynb
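
Reorder the pip installs so sentencepiece comes before
transformers_interpret, regroup the imports, drop the CamemBERT
tokenizer/model paths in favour of bert-base-multilingual-cased only,
remove the padding() helper, and renumber the Load Data section.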

---
 notebooks/Predict_XAI.ipynb | 41 +++++++++++++++----------------------
 1 file changed, 16 insertions(+), 25 deletions(-)

diff --git a/notebooks/Predict_XAI.ipynb b/notebooks/Predict_XAI.ipynb
index a75c4cc..1fdf013 100644
--- a/notebooks/Predict_XAI.ipynb
+++ b/notebooks/Predict_XAI.ipynb
@@ -26,8 +26,8 @@
       "outputs": [],
       "source": [
         "!pip install transformers==4.10.3\n",
-        "!pip install transformers_interpret\n",
-        "!pip install sentencepiece"
+        "!pip install sentencepiece\n",
+        "!pip install transformers_interpret"
       ]
     },
     {
@@ -93,22 +93,24 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": 1,
       "metadata": {
         "id": "SkErnwgMMbRj"
       },
       "outputs": [],
       "source": [
-        "import os\n",
-        "import pandas as pd \n",
-        "import numpy as np\n",
         "import pickle \n",
         "import torch\n",
-        "from tqdm import tqdm\n",
         "\n",
-        "from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, CamembertForSequenceClassification\n",
+        "from transformers import BertTokenizer, BertForSequenceClassification\n",
         "from transformers_interpret import SequenceClassificationExplainer\n",
-        "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler"
+        "\n",
+        "import numpy as np\n",
+        "import torch\n",
+        "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
+        "from tqdm import tqdm\n",
+        "import os\n",
+        "import pandas as pd \n"
       ]
     },
     {
@@ -160,7 +162,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": 3,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -177,16 +179,10 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "execution_count": 4,
       "metadata": {},
       "outputs": [],
       "source": [
-        "def padding(content, max_len):\n",
-        "    if len(content) > max_len:\n",
-        "        content[:max_len]\n",
-        "    else:\n",
-        "        content + [0] * (max_len - len(content))\n",
-        "    return \n",
         "\n",
         "def generate_dataloader(tokenizer, sentences, batch_size = 8, max_len = 512):\n",
         "\n",
@@ -311,7 +307,7 @@
         "id": "c5QKcXulhNJ-"
       },
       "source": [
-        "## 3. Load Data\n",
+        "## 2. Load Data\n",
         "\n",
         "\n",
         "!! A modifier: charger le corpus parallele : EDdA et LGE"
@@ -541,7 +537,6 @@
         "#path = \"drive/MyDrive/Classification-EDdA/\"\n",
         "path = \"../\"\n",
         "model_name = \"bert-base-multilingual-cased\"\n",
-        "#model_name = \"camembert-base\"\n",
         "model_path = path + \"models/model_\" + model_name + \"_s10000.pt\""
       ]
     },
@@ -559,12 +554,8 @@
         }
       ],
       "source": [
-        "if model_name == 'bert-base-multilingual-cased' :\n",
-        "    print('Loading Bert Tokenizer...')\n",
-        "    tokenizer = BertTokenizer.from_pretrained(model_name)\n",
-        "elif model_name == 'camembert-base':\n",
-        "    print('Loading Camembert Tokenizer...')\n",
-        "    tokenizer = CamembertTokenizer.from_pretrained(model_name)"
+        "print('Loading Bert Tokenizer...')\n",
+        "tokenizer = BertTokenizer.from_pretrained(model_name)"
       ]
     },
     {
-- 
GitLab