diff --git a/notebooks/Predict_XAI.ipynb b/notebooks/Predict_XAI.ipynb
index a75c4cc1fc3f4a90220ee062b48a2a78e80f2cb0..1fdf013533fbed1dc7eeddf6fb6e7160b9225fdd 100644
--- a/notebooks/Predict_XAI.ipynb
+++ b/notebooks/Predict_XAI.ipynb
@@ -26,8 +26,8 @@
    "outputs": [],
    "source": [
     "!pip install transformers==4.10.3\n",
-    "!pip install transformers_interpret\n",
-    "!pip install sentencepiece"
+    "!pip install sentencepiece\n",
+    "!pip install transformers_interpret"
    ]
   },
   {
@@ -93,22 +93,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 1,
    "metadata": {
     "id": "SkErnwgMMbRj"
    },
    "outputs": [],
    "source": [
-    "import os\n",
-    "import pandas as pd \n",
-    "import numpy as np\n",
     "import pickle \n",
     "import torch\n",
-    "from tqdm import tqdm\n",
     "\n",
-    "from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, CamembertForSequenceClassification\n",
+    "from transformers import BertTokenizer, BertForSequenceClassification\n",
     "from transformers_interpret import SequenceClassificationExplainer\n",
-    "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler"
+    "\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
+    "from tqdm import tqdm\n",
+    "import os\n",
+    "import pandas as pd \n"
    ]
   },
   {
@@ -160,7 +162,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -177,16 +179,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def padding(content, max_len):\n",
-    "    if len(content) > max_len:\n",
-    "        content[:max_len]\n",
-    "    else:\n",
-    "        content + [0] * (max_len - len(content))\n",
-    "    return \n",
     "\n",
     "def generate_dataloader(tokenizer, sentences, batch_size = 8, max_len = 512):\n",
     "\n",
@@ -311,7 +307,7 @@
     "id": "c5QKcXulhNJ-"
    },
    "source": [
-    "## 3. Load Data\n",
+    "## 2. Load Data\n",
     "\n",
     "\n",
     "!! A modifier: charger le corpus parallele : EDdA et LGE"
    ]
   },
@@ -541,7 +537,6 @@
     "#path = \"drive/MyDrive/Classification-EDdA/\"\n",
     "path = \"../\"\n",
     "model_name = \"bert-base-multilingual-cased\"\n",
-    "#model_name = \"camembert-base\"\n",
     "model_path = path + \"models/model_\" + model_name + \"_s10000.pt\""
    ]
   },
@@ -559,12 +554,8 @@
     }
    ],
    "source": [
-    "if model_name == 'bert-base-multilingual-cased' :\n",
-    "    print('Loading Bert Tokenizer...')\n",
-    "    tokenizer = BertTokenizer.from_pretrained(model_name)\n",
-    "elif model_name == 'camembert-base':\n",
-    "    print('Loading Camembert Tokenizer...')\n",
-    "    tokenizer = CamembertTokenizer.from_pretrained(model_name)"
+    "print('Loading Bert Tokenizer...')\n",
+    "tokenizer = BertTokenizer.from_pretrained(model_name)"
    ]
   },
   {