From 53b9ca9f81db570e0b843b4fd1adf930129af099 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Sun, 10 Jul 2022 18:38:40 +0200
Subject: [PATCH] Create Classification_BiLSTM.ipynb

---
 notebooks/Classification_BiLSTM.ipynb | 1893 +++++++++++++++++++++++++
 1 file changed, 1893 insertions(+)
 create mode 100644 notebooks/Classification_BiLSTM.ipynb

diff --git a/notebooks/Classification_BiLSTM.ipynb b/notebooks/Classification_BiLSTM.ipynb
new file mode 100644
index 0000000..b10a1ad
--- /dev/null
+++ b/notebooks/Classification_BiLSTM.ipynb
@@ -0,0 +1,1893 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "EDdA-Classification_LSTM.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "0yFsoHXX8Iyy"
+      },
+      "source": [
+        "# Deep learning for EDdA classification"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "tFlUCDL2778i"
+      },
+      "source": [
+        "## Setup colab environment"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Sp8d_Uus7SHJ",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "f0dbf6d3-0d66-46b9-c00b-6fd788f4c689"
+      },
+      "source": [
+        "from google.colab import drive\n",
+        "drive.mount('/content/drive')"
+      ],
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Mounted at /content/drive\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "jQBu-p6hBU-j"
+      },
+      "source": [
+        "### Install packages"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "bTIXsF6kBUdh"
+      },
+      "source": [
+        "#!pip install zeugma\n",
+        "#!pip install plot_model"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "56-04SNF8BMx"
+      },
+      "source": [
+        "### Import libraries"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "HwWkSznz7SEv"
+      },
+      "source": [
+        "import pandas as pd\n",
+        "import numpy as np\n",
+        "import matplotlib.pyplot as plt\n",
+        "import pickle\n",
+        "import os\n",
+        "\n",
+        "from tqdm import tqdm\n",
+        "import requests, zipfile, io\n",
+        "import codecs\n",
+        "\n",
+        "from sklearn import preprocessing # LabelEncoder\n",
+        "from sklearn.metrics import classification_report\n",
+        "from sklearn.metrics import confusion_matrix\n",
+        "\n",
+        "from keras.preprocessing import sequence\n",
+        "from keras.preprocessing.text import Tokenizer\n",
+        "\n",
+        "from keras.layers import BatchNormalization, Input, Reshape, Conv1D, MaxPool1D, Conv2D, MaxPool2D, Concatenate\n",
+        "from keras.layers import Embedding, Dropout, Flatten, Dense, LSTM, Bidirectional\n",
+        "from keras.models import Model, load_model\n",
+        "from keras.callbacks import ModelCheckpoint\n"
+      ],
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "xrekV6W978l4"
+      },
+      "source": [
+        "### Utils functions"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "4LJ5blQR7PUe"
+      },
+      "source": [
+        "\n",
+        "def resample_classes(df, classColumnName, numberOfInstances):\n",
+        "  \"\"\"Return df with each class of classColumnName capped at numberOfInstances randomly drawn rows.\"\"\"\n",
+        "  def cap(group):  # drawn without replacement; smaller classes keep every row\n",
+        "    return group.loc[np.random.choice(group.index, min(len(group), numberOfInstances), replace=False), :]\n",
+        "  return df.groupby(classColumnName, as_index=False).apply(cap)\n",
+        "    \n"
+      ],
+      "execution_count": 3,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "-Rh3JMDh7zYd"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": 3,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "MtLr35eM753e"
+      },
+      "source": [
+        "## Load Data"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "FnbNT4NF7zal",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "300e8fb1-43a5-4f25-9afa-15e04730b0f4"
+      },
+      "source": [
+        "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
+        "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv"
+      ],
+      "execution_count": 4,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "--2022-02-18 07:32:19--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
+            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
+            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
+            "HTTP request sent, awaiting response... 200 OK\n",
+            "Length: 175634219 (167M) [text/tab-separated-values]\n",
+            "Saving to: ‘training_set.tsv’\n",
+            "\n",
+            "training_set.tsv    100%[===================>] 167.50M  28.1MB/s    in 6.6s    \n",
+            "\n",
+            "2022-02-18 07:32:26 (25.5 MB/s) - ‘training_set.tsv’ saved [175634219/175634219]\n",
+            "\n",
+            "--2022-02-18 07:32:26--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv\n",
+            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
+            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
+            "HTTP request sent, awaiting response... 200 OK\n",
+            "Length: 42730598 (41M) [text/tab-separated-values]\n",
+            "Saving to: ‘test_set.tsv’\n",
+            "\n",
+            "test_set.tsv        100%[===================>]  40.75M  19.6MB/s    in 2.1s    \n",
+            "\n",
+            "2022-02-18 07:32:29 (19.6 MB/s) - ‘test_set.tsv’ saved [42730598/42730598]\n",
+            "\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Loading dataset"
+      ],
+      "metadata": {
+        "id": "UHushJ1XfUj9"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "train_path = 'training_set.tsv'\n",
+        "test_path =  'test_set.tsv'"
+      ],
+      "metadata": {
+        "id": "Q4te2c0bfvaJ"
+      },
+      "execution_count": 5,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "nRLaQUO97zcq"
+      },
+      "source": [
+        "df_train = pd.read_csv(train_path, sep=\"\\t\")\n",
+        "\n",
+        "#df_train = resample_classes(df_train, columnClass, maxOfInstancePerClass)\n"
+      ],
+      "execution_count": 6,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "df_train.sample(5)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 600
+        },
+        "id": "2MvHEc7zVK1N",
+        "outputId": "57e4ff59-b74f-49ff-f441-5b4fba0973cf"
+      },
+      "execution_count": 7,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/html": [
+              "\n",
+              "  <div id=\"df-00b9a714-8d6b-461c-83fb-70e844fc260e\">\n",
+              "    <div class=\"colab-df-container\">\n",
+              "      <div>\n",
+              "<style scoped>\n",
+              "    .dataframe tbody tr th:only-of-type {\n",
+              "        vertical-align: middle;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe tbody tr th {\n",
+              "        vertical-align: top;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe thead th {\n",
+              "        text-align: right;\n",
+              "    }\n",
+              "</style>\n",
+              "<table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              "    <tr style=\"text-align: right;\">\n",
+              "      <th></th>\n",
+              "      <th>volume</th>\n",
+              "      <th>numero</th>\n",
+              "      <th>head</th>\n",
+              "      <th>normClass</th>\n",
+              "      <th>classEDdA</th>\n",
+              "      <th>author</th>\n",
+              "      <th>id_enccre</th>\n",
+              "      <th>domaine_enccre</th>\n",
+              "      <th>ensemble_domaine_enccre</th>\n",
+              "      <th>content</th>\n",
+              "      <th>contentWithoutClass</th>\n",
+              "      <th>firstParagraph</th>\n",
+              "      <th>nb_words</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <th>31223</th>\n",
+              "      <td>1</td>\n",
+              "      <td>1494</td>\n",
+              "      <td>AJUSTOIRE</td>\n",
+              "      <td>La Monnoie</td>\n",
+              "      <td>a la Monnoie.</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v1-1010-0</td>\n",
+              "      <td>monnaie</td>\n",
+              "      <td>Monnaie</td>\n",
+              "      <td>AJUSTOIRE, s. m. (à la Monnoie.) est une balan...</td>\n",
+              "      <td>ajustoire s. m. monnoie balance \\n sert ajuste...</td>\n",
+              "      <td>ajustoire s. m. monnoie balance \\n sert ajuste...</td>\n",
+              "      <td>85</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>710</th>\n",
+              "      <td>14</td>\n",
+              "      <td>719</td>\n",
+              "      <td>REPAITRIR</td>\n",
+              "      <td>Grammaire</td>\n",
+              "      <td>Gram.</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v14-375-0</td>\n",
+              "      <td>grammaire</td>\n",
+              "      <td>Grammaire</td>\n",
+              "      <td>REPAITRIR, v. act. (Gram.) paîtrir de-rechef.\\...</td>\n",
+              "      <td>repaitrir vers act   paîtrir de-rechef \\n arti...</td>\n",
+              "      <td>repaitrir vers act   paîtrir de-rechef \\n arti...</td>\n",
+              "      <td>19</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>4668</th>\n",
+              "      <td>4</td>\n",
+              "      <td>1838</td>\n",
+              "      <td>Courir</td>\n",
+              "      <td>Géographie</td>\n",
+              "      <td>en Géographie</td>\n",
+              "      <td>d'Alembert</td>\n",
+              "      <td>v4-834-5</td>\n",
+              "      <td>géographie</td>\n",
+              "      <td>Géographie</td>\n",
+              "      <td>Courir, se dit aussi en Géographie. Cette suit...</td>\n",
+              "      <td>courir suite \\n montagne -on court est-ouest \\...</td>\n",
+              "      <td>courir suite \\n montagne -on court est-ouest \\...</td>\n",
+              "      <td>68</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>25846</th>\n",
+              "      <td>1</td>\n",
+              "      <td>2789</td>\n",
+              "      <td>Anneaux</td>\n",
+              "      <td>Manufacture en soie</td>\n",
+              "      <td>manufactures en soie</td>\n",
+              "      <td>unsigned</td>\n",
+              "      <td>v1-2037-10</td>\n",
+              "      <td>manufacture</td>\n",
+              "      <td>Arts et métiers</td>\n",
+              "      <td>Anneaux, s. m. pl. ce sont dans les manufactur...</td>\n",
+              "      <td>anneau s. m. pl manufacture \\n soie petit cerc...</td>\n",
+              "      <td>anneau s. m. pl manufacture \\n soie petit cerc...</td>\n",
+              "      <td>214</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>21034</th>\n",
+              "      <td>1</td>\n",
+              "      <td>2118</td>\n",
+              "      <td>AMBLEUR</td>\n",
+              "      <td>Manège</td>\n",
+              "      <td>Manege</td>\n",
+              "      <td>Eidous</td>\n",
+              "      <td>v1-1516-0</td>\n",
+              "      <td>manège</td>\n",
+              "      <td>Maréchage - Manège</td>\n",
+              "      <td>AMBLEUR, s. m. (Manege.) Officier de la grande...</td>\n",
+              "      <td>ambleur s. m. officier grand \\n petit écurie r...</td>\n",
+              "      <td>ambleur s. m. officier grand \\n petit écurie r...</td>\n",
+              "      <td>25</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table>\n",
+              "</div>\n",
+              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-00b9a714-8d6b-461c-83fb-70e844fc260e')\"\n",
+              "              title=\"Convert this dataframe to an interactive table.\"\n",
+              "              style=\"display:none;\">\n",
+              "        \n",
+              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
+              "       width=\"24px\">\n",
+              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
+              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
+              "  </svg>\n",
+              "      </button>\n",
+              "      \n",
+              "  <style>\n",
+              "    .colab-df-container {\n",
+              "      display:flex;\n",
+              "      flex-wrap:wrap;\n",
+              "      gap: 12px;\n",
+              "    }\n",
+              "\n",
+              "    .colab-df-convert {\n",
+              "      background-color: #E8F0FE;\n",
+              "      border: none;\n",
+              "      border-radius: 50%;\n",
+              "      cursor: pointer;\n",
+              "      display: none;\n",
+              "      fill: #1967D2;\n",
+              "      height: 32px;\n",
+              "      padding: 0 0 0 0;\n",
+              "      width: 32px;\n",
+              "    }\n",
+              "\n",
+              "    .colab-df-convert:hover {\n",
+              "      background-color: #E2EBFA;\n",
+              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
+              "      fill: #174EA6;\n",
+              "    }\n",
+              "\n",
+              "    [theme=dark] .colab-df-convert {\n",
+              "      background-color: #3B4455;\n",
+              "      fill: #D2E3FC;\n",
+              "    }\n",
+              "\n",
+              "    [theme=dark] .colab-df-convert:hover {\n",
+              "      background-color: #434B5C;\n",
+              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
+              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
+              "      fill: #FFFFFF;\n",
+              "    }\n",
+              "  </style>\n",
+              "\n",
+              "      <script>\n",
+              "        const buttonEl =\n",
+              "          document.querySelector('#df-00b9a714-8d6b-461c-83fb-70e844fc260e button.colab-df-convert');\n",
+              "        buttonEl.style.display =\n",
+              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
+              "\n",
+              "        async function convertToInteractive(key) {\n",
+              "          const element = document.querySelector('#df-00b9a714-8d6b-461c-83fb-70e844fc260e');\n",
+              "          const dataTable =\n",
+              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
+              "                                                     [key], {});\n",
+              "          if (!dataTable) return;\n",
+              "\n",
+              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
+              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
+              "            + ' to learn more about interactive tables.';\n",
+              "          element.innerHTML = '';\n",
+              "          dataTable['output_type'] = 'display_data';\n",
+              "          await google.colab.output.renderOutput(dataTable, element);\n",
+              "          const docLink = document.createElement('div');\n",
+              "          docLink.innerHTML = docLinkHtml;\n",
+              "          element.appendChild(docLink);\n",
+              "        }\n",
+              "      </script>\n",
+              "    </div>\n",
+              "  </div>\n",
+              "  "
+            ],
+            "text/plain": [
+              "       volume  ...  nb_words\n",
+              "31223       1  ...        85\n",
+              "710        14  ...        19\n",
+              "4668        4  ...        68\n",
+              "25846       1  ...       214\n",
+              "21034       1  ...        25\n",
+              "\n",
+              "[5 rows x 13 columns]"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 7
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Configuration\n"
+      ],
+      "metadata": {
+        "id": "-63bh_cKfN4p"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "columnText = 'contentWithoutClass'\n",
+        "columnClass = 'ensemble_domaine_enccre'\n",
+        "\n",
+        "minOfInstancePerClass = 0\n",
+        "maxOfInstancePerClass = 1500\n",
+        "\n",
+        "batch_size = 64\n",
+        "validation_split = 0.20\n",
+        "#max_len = 512 # \n",
+        "max_nb_words = 20000      # vocabulary size\n",
+        "max_sequence_length = 512  # maximum length of a 'document' (in tokens)\n",
+        "epochs = 10\n",
+        "\n",
+        "#embedding_name = \"fasttext\" \n",
+        "#embedding_dim = 300 \n",
+        "\n",
+        "embedding_name = \"glove.6B.100d\"\n",
+        "embedding_dim = 100 \n",
+        "\n",
+        "path = \"drive/MyDrive/Classification-EDdA/\"\n",
+        "encoder_filename = \"label_encoder.pkl\"\n",
+        "tokenizer_filename = \"tokenizer_keras.pkl\""
+      ],
+      "metadata": {
+        "id": "nsRuyzYUfOBg"
+      },
+      "execution_count": 8,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Preprocessing\n"
+      ],
+      "metadata": {
+        "id": "ZDz-Y1LCfQt0"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "if maxOfInstancePerClass != 10000:\n",
+        "  df_train = resample_classes(df_train, columnClass, maxOfInstancePerClass)"
+      ],
+      "metadata": {
+        "id": "vw7UlS6-v0sN"
+      },
+      "execution_count": 9,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "vGWAgBH87ze8"
+      },
+      "source": [
+        "labels = df_train[columnClass]\n",
+        "numberOfClasses = labels.nunique()\n",
+        "\n",
+        "\n",
+        "# Reuse the LabelEncoder pickled at path+encoder_filename when that file exists;\n",
+        "# otherwise fit one on the training labels and pickle it there.\n",
+        "if not os.path.isfile(path + encoder_filename):\n",
+        "  encoder = preprocessing.LabelEncoder()\n",
+        "  encoder.fit(labels)\n",
+        "  with open(path + encoder_filename, 'wb') as file:\n",
+        "    pickle.dump(encoder, file)\n",
+        "else:\n",
+        "  # NOTE(review): a pickle fitted on a different label set is loaded as-is -- delete the file to refit\n",
+        "  with open(path + encoder_filename, 'rb') as file:\n",
+        "    encoder = pickle.load(file)\n",
+        "\n",
+        "# replace the class names by the encoder's integer codes\n",
+        "labels = encoder.transform(labels)"
+      ],
+      "execution_count": 10,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "encoder.classes_"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "SME4vvVhW9Sn",
+        "outputId": "00feb4ff-88b0-49b7-97ca-94219632371a"
+      },
+      "execution_count": 11,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "array(['Agriculture - Economie rustique', 'Anatomie', 'Antiquité',\n",
+              "       'Architecture', 'Arts et métiers', 'Beaux-arts',\n",
+              "       'Belles-lettres - Poésie', 'Blason', 'Caractères', 'Chasse',\n",
+              "       'Chimie', 'Commerce', 'Droit - Jurisprudence',\n",
+              "       'Economie domestique', 'Grammaire', 'Géographie', 'Histoire',\n",
+              "       'Histoire naturelle', 'Jeu', 'Marine', 'Maréchage - Manège',\n",
+              "       'Mathématiques', 'Mesure', 'Militaire (Art) - Guerre - Arme',\n",
+              "       'Minéralogie', 'Monnaie', 'Musique', 'Médailles',\n",
+              "       'Médecine - Chirurgie', 'Métiers', 'Pharmacie', 'Philosophie',\n",
+              "       'Physique - [Sciences physico-mathématiques]', 'Politique',\n",
+              "       'Pêche', 'Religion', 'Spectacle', 'Superstition'], dtype=object)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 11
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "labels_index = dict(zip(list(encoder.classes_), encoder.transform(list(encoder.classes_))))"
+      ],
+      "metadata": {
+        "id": "nIzWQ2VbW_UO"
+      },
+      "execution_count": 12,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "labels_index"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "4e7ggEGiXC_W",
+        "outputId": "e35c232e-cc24-47e7-aa26-0e5af7a570b4"
+      },
+      "execution_count": 13,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "{'Agriculture - Economie rustique': 0,\n",
+              " 'Anatomie': 1,\n",
+              " 'Antiquité': 2,\n",
+              " 'Architecture': 3,\n",
+              " 'Arts et métiers': 4,\n",
+              " 'Beaux-arts': 5,\n",
+              " 'Belles-lettres - Poésie': 6,\n",
+              " 'Blason': 7,\n",
+              " 'Caractères': 8,\n",
+              " 'Chasse': 9,\n",
+              " 'Chimie': 10,\n",
+              " 'Commerce': 11,\n",
+              " 'Droit - Jurisprudence': 12,\n",
+              " 'Economie domestique': 13,\n",
+              " 'Grammaire': 14,\n",
+              " 'Géographie': 15,\n",
+              " 'Histoire': 16,\n",
+              " 'Histoire naturelle': 17,\n",
+              " 'Jeu': 18,\n",
+              " 'Marine': 19,\n",
+              " 'Maréchage - Manège': 20,\n",
+              " 'Mathématiques': 21,\n",
+              " 'Mesure': 22,\n",
+              " 'Militaire (Art) - Guerre - Arme': 23,\n",
+              " 'Minéralogie': 24,\n",
+              " 'Monnaie': 25,\n",
+              " 'Musique': 26,\n",
+              " 'Médailles': 27,\n",
+              " 'Médecine - Chirurgie': 28,\n",
+              " 'Métiers': 29,\n",
+              " 'Pharmacie': 30,\n",
+              " 'Philosophie': 31,\n",
+              " 'Physique - [Sciences physico-mathématiques]': 32,\n",
+              " 'Politique': 33,\n",
+              " 'Pêche': 34,\n",
+              " 'Religion': 35,\n",
+              " 'Spectacle': 36,\n",
+              " 'Superstition': 37}"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 13
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Loading pre-trained embeddings\n",
+        "\n",
+        "#### FastText"
+      ],
+      "metadata": {
+        "id": "pguj5vY4vUnj"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# download FastText (too large to keep on the Drive)\n",
+        "zip_file_url = \"https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip\"\n",
+        "r = requests.get(zip_file_url)\n",
+        "z = zipfile.ZipFile(io.BytesIO(r.content))\n",
+        "z.extractall()"
+      ],
+      "metadata": {
+        "id": "p25AssV4vUuj"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "print('loading word embeddings FastText...')\n",
+        "\n",
+        "embeddings_index = {}\n",
+        "# the with-block closes the file even if a line fails to parse\n",
+        "with codecs.open('crawl-300d-2M.vec', encoding='utf-8') as f:\n",
+        "    next(f, None)  # skip the header line (vocabulary size and dimension): it is not a word vector\n",
+        "    for line in tqdm(f):\n",
+        "        values = line.rstrip().rsplit(' ')\n",
+        "        word = values[0]\n",
+        "        coefs = np.asarray(values[1:], dtype='float32')\n",
+        "        embeddings_index[word] = coefs\n",
+        "\n",
+        "print('found %s word vectors' % len(embeddings_index))"
+      ],
+      "metadata": {
+        "id": "MYAljz1jvUxH"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "#### GLOVE"
+      ],
+      "metadata": {
+        "id": "4KdmWdnhvU6A"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# download Glove\n",
+        "#zip_file_url = \"https://nlp.stanford.edu/data/glove.6B.zip\"\n",
+        "#r = requests.get(zip_file_url)\n",
+        "#z = zipfile.ZipFile(io.BytesIO(r.content))\n",
+        "#z.extractall()"
+      ],
+      "metadata": {
+        "id": "38rp4FSnvVBE"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "print('loading word embeddings GLOVE...')\n",
+        "\n",
+        "embeddings_index = {}\n",
+        "# with-block closes the file even on a parse error; each line is a token followed by its vector\n",
+        "with open(path+\"embeddings/\"+embedding_name+\".txt\", encoding='utf-8') as f:\n",
+        "    for line in tqdm(f):\n",
+        "        values = line.split()\n",
+        "        word = values[0]\n",
+        "        coefs = np.asarray(values[1:], dtype='float32')\n",
+        "        embeddings_index[word] = coefs\n",
+        "\n",
+        "print('Found %s word vectors.' % len(embeddings_index))"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "RVvIk8HdvVDZ",
+        "outputId": "5ca6c955-1764-42e9-c91c-4fab3953312d"
+      },
+      "execution_count": 14,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "loading word embeddings GLOVE...\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "400000it [00:19, 20293.28it/s]"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Found 400000 word vectors.\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "HuUVfklf-dSR"
+      },
+      "source": [
+        "## Training models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "NTNh6kMTp_eU",
+        "outputId": "19b9945e-36ee-4ee7-ae45-43eb22122b93"
+      },
+      "source": [
+        "\n",
+        "raw_docs_train = df_train[columnText].tolist()\n",
+        "\n",
+        "print(\"pre-processing train data...\")\n",
+        "\n",
+        "# Reuse the tokenizer pickled at path+tokenizer_filename when that file exists;\n",
+        "# otherwise fit one on the training texts and pickle it there.\n",
+        "if not os.path.isfile(path + tokenizer_filename):\n",
+        "  tokenizer = Tokenizer(num_words=max_nb_words)\n",
+        "  tokenizer.fit_on_texts(raw_docs_train)\n",
+        "  with open(path + tokenizer_filename, 'wb') as file:\n",
+        "    pickle.dump(tokenizer, file)\n",
+        "else:\n",
+        "  with open(path + tokenizer_filename, 'rb') as file:\n",
+        "    tokenizer = pickle.load(file)\n",
+        "\n",
+        "sequences = tokenizer.texts_to_sequences(raw_docs_train)\n",
+        "\n",
+        "word_index = tokenizer.word_index\n",
+        "print(\"dictionary size: \", len(word_index))\n",
+        "\n",
+        "# pad (or cut) every sequence to exactly max_sequence_length token ids\n",
+        "data = sequence.pad_sequences(sequences, maxlen=max_sequence_length)\n",
+        "\n",
+        "print('Shape of data tensor:', data.shape)\n",
+        "print('Shape of label tensor:', labels.shape)\n",
+        "print(labels)"
+      ],
+      "execution_count": 15,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "pre-processing train data...\n",
+            "dictionary size:  190508\n",
+            "Shape of data tensor: (27381, 512)\n",
+            "Shape of label tensor: (27381,)\n",
+            "[ 0  0  0 ... 37 37 37]\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "data"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "FtmB9Vgcaxgl",
+        "outputId": "4165d4d1-a778-4831-9342-9e48f62763bd"
+      },
+      "execution_count": 16,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "array([[    0,     0,     0, ...,   147,  7021,  7021],\n",
+              "       [    0,     0,     0, ...,   613,  9556,   861],\n",
+              "       [    0,     0,     0, ...,  3733,  4758,    29],\n",
+              "       ...,\n",
+              "       [    0,     0,     0, ...,   980,  3069,   347],\n",
+              "       [    0,     0,     0, ...,  2341,   231, 15988],\n",
+              "       [    0,     0,     0, ...,  4626,    27,    22]], dtype=int32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 16
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# split the data into a training set and a validation set\n",
+        "\n",
+        "indices = np.arange(data.shape[0])\n",
+        "np.random.shuffle(indices)  # no seed is set in this cell, so the split can change from run to run\n",
+        "data = data[indices]\n",
+        "labels = labels[indices]\n",
+        "\n",
+        "nb_validation_samples = int(validation_split * data.shape[0])\n",
+        "# slice at an explicit index: with 0 validation samples, data[:-0] would leave the training set empty\n",
+        "split_at = data.shape[0] - nb_validation_samples\n",
+        "\n",
+        "x_train, y_train = data[:split_at], labels[:split_at]\n",
+        "x_val, y_val = data[split_at:], labels[split_at:]\n"
+      ],
+      "metadata": {
+        "id": "sHYJ4P-YDfFb"
+      },
+      "execution_count": 17,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "wGjQI0YgpQAS",
+        "outputId": "e155402c-f124-4619-fe43-49c4b5865df5"
+      },
+      "source": [
+        "# Assemble the weight matrix for the Embedding layer: row i receives the\n",
+        "# pretrained vector of the word with index i; words without one stay all-zero.\n",
+        "print('preparing embedding matrix...')\n",
+        "matrix_shape = (len(word_index) + 1, embedding_dim)\n",
+        "embedding_matrix = np.zeros(matrix_shape)\n",
+        "for word, i in word_index.items():\n",
+        "    embedding_vector = embeddings_index.get(word)\n",
+        "    if embedding_vector is None:\n",
+        "        continue\n",
+        "    embedding_matrix[i] = embedding_vector\n"
+      ],
+      "execution_count": 18,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "preparing embedding matrix...\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Bidirectional LSTM classifier over frozen, pretrained word embeddings:\n",
+        "#   token ids -> Embedding (weights from embedding_matrix, not trainable)\n",
+        "#             -> Bidirectional(LSTM(100)) -> Dropout -> Dense softmax\n",
+        "\n",
+        "drop = 0.2  # dropout rate applied to the BiLSTM output\n",
+        "\n",
+        "# The embedding lookup is initialised from embedding_matrix and kept frozen,\n",
+        "# so only the LSTM and the classification layer are trained.\n",
+        "embedding_layer = Embedding(\n",
+        "    len(word_index) + 1,\n",
+        "    embedding_dim,\n",
+        "    input_length=max_sequence_length,\n",
+        "    weights=[embedding_matrix],\n",
+        "    trainable=False,\n",
+        ")\n",
+        "\n",
+        "# `shape` has to be a tuple; `(max_sequence_length)` without the trailing\n",
+        "# comma is a plain int.\n",
+        "inputs = Input(shape=(max_sequence_length,), dtype='int32')\n",
+        "embedding = embedding_layer(inputs)\n",
+        "print(embedding.shape)\n",
+        "\n",
+        "# The two directions are concatenated, giving 2 * 100 features per sample.\n",
+        "lstm = Bidirectional(LSTM(100))(embedding)\n",
+        "dropout = Dropout(drop)(lstm)\n",
+        "\n",
+        "# One softmax output per class.\n",
+        "output = Dense(len(labels_index), activation='softmax')(dropout)\n",
+        "model = Model(inputs=inputs, outputs=output)\n",
+        "\n",
+        "# Save the model each time the validation accuracy improves; 'val_acc' matches\n",
+        "# the metric name 'acc' passed to compile().\n",
+        "checkpoint = ModelCheckpoint(\n",
+        "    'weights_lstm_sentece.hdf5',\n",
+        "    monitor='val_acc',\n",
+        "    verbose=1,\n",
+        "    save_best_only=True,\n",
+        "    mode='auto',\n",
+        ")\n",
+        "\n",
+        "# The labels are integer class ids, hence the sparse categorical cross-entropy.\n",
+        "model.compile(\n",
+        "    optimizer='adam',\n",
+        "    loss='sparse_categorical_crossentropy',\n",
+        "    metrics=['acc'],\n",
+        ")\n",
+        "model.summary()"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "OUphqlYJCC9n",
+        "outputId": "3fc0fc31-321c-46dd-a7c7-ff046bc5c819"
+      },
+      "execution_count": 19,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "(None, 512, 100)\n",
+            "Model: \"model\"\n",
+            "_________________________________________________________________\n",
+            " Layer (type)                Output Shape              Param #   \n",
+            "=================================================================\n",
+            " input_1 (InputLayer)        [(None, 512)]             0         \n",
+            "                                                                 \n",
+            " embedding (Embedding)       (None, 512, 100)          19050900  \n",
+            "                                                                 \n",
+            " bidirectional (Bidirectiona  (None, 200)              160800    \n",
+            " l)                                                              \n",
+            "                                                                 \n",
+            " dropout (Dropout)           (None, 200)               0         \n",
+            "                                                                 \n",
+            " dense (Dense)               (None, 38)                7638      \n",
+            "                                                                 \n",
+            "=================================================================\n",
+            "Total params: 19,219,338\n",
+            "Trainable params: 168,438\n",
+            "Non-trainable params: 19,050,900\n",
+            "_________________________________________________________________\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Train on the training split and evaluate on the held-out split after each\n",
+        "# epoch; the checkpoint callback is invoked during training.\n",
+        "history = model.fit(\n",
+        "    x_train, y_train,\n",
+        "    batch_size=batch_size, epochs=epochs, verbose=1,\n",
+        "    callbacks=[checkpoint], validation_data=(x_val, y_val))"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "3aUBdLdNCEGK",
+        "outputId": "20203bf0-98ae-4a56-86cb-76dd8a36ac9f"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Epoch 1/10\n",
+            "343/343 [==============================] - ETA: 0s - loss: 2.7653 - acc: 0.2694\n",
+            "Epoch 1: val_acc improved from -inf to 0.38386, saving model to weights_lstm_sentece.hdf5\n",
+            "343/343 [==============================] - 388s 1s/step - loss: 2.7653 - acc: 0.2694 - val_loss: 2.3262 - val_acc: 0.3839\n",
+            "Epoch 2/10\n",
+            "343/343 [==============================] - ETA: 0s - loss: 2.2038 - acc: 0.4024\n",
+            "Epoch 2: val_acc improved from 0.38386 to 0.43006, saving model to weights_lstm_sentece.hdf5\n",
+            "343/343 [==============================] - 400s 1s/step - loss: 2.2038 - acc: 0.4024 - val_loss: 2.0739 - val_acc: 0.4301\n",
+            "Epoch 3/10\n",
+            "343/343 [==============================] - ETA: 0s - loss: 1.9875 - acc: 0.4516"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "ZcYbQsQEJKLq"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\n",
+        "# Learning curves from model.fit(): one figure per metric, each comparing the\n",
+        "# training series with the validation series.\n",
+        "curves = [\n",
+        "    # (history key, axis label, legend position)\n",
+        "    ('acc', 'accuracy', 'lower right'),\n",
+        "    ('loss', 'loss', 'upper right'),\n",
+        "]\n",
+        "\n",
+        "for key, label, legend_loc in curves:\n",
+        "    plt.plot(history.history[key])\n",
+        "    plt.plot(history.history['val_' + key])\n",
+        "    plt.title('model ' + label)\n",
+        "    plt.ylabel(label)\n",
+        "    plt.xlabel('epoch')\n",
+        "    plt.legend(['train', 'validation'], loc=legend_loc)\n",
+        "    plt.show()\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "Job-3uMvJKN_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "Uw6YR76p_AF0"
+      },
+      "source": [
+        "## Saving models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "name = f\"lstm_{embedding_name}_s{maxOfInstancePerClass}\""
+      ],
+      "metadata": {
+        "id": "irB4mATeeUcw"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "ykTp9lyRaAma"
+      },
+      "source": [
+        "model.save(f\"{path}{name}.h5\")\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "5J4xDoqRUSfS"
+      },
+      "source": [
+        "# save embeddings\n",
+        "\n",
+        "# saving embeddings index \n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "HHlEtipG_Cp0"
+      },
+      "source": [
+        "## Loading models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "fKt8ft1t_Cxx"
+      },
+      "source": [
+        "# Restore the saved network, then unpickle the tokenizer and the label encoder.\n",
+        "# pickle.load can execute arbitrary code: only open files you produced yourself.\n",
+        "model = load_model(path + name + \".h5\")\n",
+        "with open(path + tokenizer_filename, 'rb') as tokenizer_file:\n",
+        "  tokenizer = pickle.load(tokenizer_file)\n",
+        "with open(path + encoder_filename, 'rb') as encoder_file:\n",
+        "  encoder = pickle.load(encoder_file)\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "zbS4poso-3k7"
+      },
+      "source": [
+        "## Evaluation"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "df_test = pd.read_csv(test_path, sep=\"\\t\")\n"
+      ],
+      "metadata": {
+        "id": "KWORvsadvBbr"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Encode the test texts with the loaded tokenizer, pad them to\n",
+        "# max_sequence_length and keep the highest-scoring class index per text.\n",
+        "test_texts = df_test[columnText].tolist()\n",
+        "test_labels = df_test[columnClass].tolist()\n",
+        "\n",
+        "test_sequences = tokenizer.texts_to_sequences(test_texts)\n",
+        "test_input = sequence.pad_sequences(test_sequences, maxlen=max_sequence_length)\n",
+        "test_predictions_probas = model.predict(test_input)\n",
+        "test_predictions = np.argmax(test_predictions_probas, axis=-1)"
+      ],
+      "metadata": {
+        "id": "Xr0o-0i5t38G"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Convert the predicted class indices back to the original labels with the\n",
+        "# encoder, then compare them with the gold labels of the test set.\n",
+        "test_intent_predictions = encoder.inverse_transform(test_predictions)\n",
+        "\n",
+        "n_correct = sum(test_intent_predictions == test_labels)\n",
+        "print('accuracy: ', n_correct / len(test_labels))\n",
+        "test_report = classification_report(test_labels, test_intent_predictions)\n",
+        "print(\"Precision, Recall and F1-Score:\\n\\n\", test_report)\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "lSn8yZ0gt3-d"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Per-class evaluation report for the test set. Files written under `path`:\n",
+        "#   reports/report_<run>.csv            per-class metrics and confusion counts\n",
+        "#   reports/report_<run>.txt            run name and weighted averages\n",
+        "#   predictions/predictions_<run>.csv   gold and predicted label of every row\n",
+        "\n",
+        "# Use a separate variable instead of overwriting `name`: re-running the cell no\n",
+        "# longer stacks \"test_\" prefixes and `path+name+\".h5\"` still points to the model.\n",
+        "report_name = \"test_\" + name\n",
+        "\n",
+        "class_names = list(encoder.classes_)\n",
+        "\n",
+        "report = classification_report(test_labels, test_intent_predictions, output_dict = True)\n",
+        "accuracy = report['accuracy']\n",
+        "weighted_avg = report['weighted avg']\n",
+        "\n",
+        "# A class that occurs neither in the gold labels nor in the predictions has no\n",
+        "# entry in the report; give it an all-zero row instead of raising KeyError.\n",
+        "no_samples = {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0}\n",
+        "per_class = [report.get(c, no_samples) for c in class_names]\n",
+        "\n",
+        "# Pin the matrix order to encoder.classes_ so that row k always describes\n",
+        "# class_names[k], even when some class is missing from the test data.\n",
+        "cnf_matrix = confusion_matrix(test_labels, test_intent_predictions, labels=class_names)\n",
+        "TP = np.diag(cnf_matrix)\n",
+        "FP = cnf_matrix.sum(axis=0) - TP  # column total minus the diagonal\n",
+        "FN = cnf_matrix.sum(axis=1) - TP  # row total minus the diagonal\n",
+        "TN = cnf_matrix.sum() - (FP + FN + TP)\n",
+        "\n",
+        "dff = pd.DataFrame({\n",
+        "    'className': class_names,\n",
+        "    'precision': [row['precision'] for row in per_class],\n",
+        "    'recall': [row['recall'] for row in per_class],\n",
+        "    'f1-score': [row['f1-score'] for row in per_class],\n",
+        "    'support': [row['support'] for row in per_class],\n",
+        "    'FP': FP,\n",
+        "    'FN': FN,\n",
+        "    'TP': TP,\n",
+        "    'TN': TN,\n",
+        "})\n",
+        "\n",
+        "print(report_name)\n",
+        "print(weighted_avg)\n",
+        "print(accuracy)\n",
+        "print(dff)\n",
+        "\n",
+        "dff.to_csv(path+\"reports/report_\"+report_name+\".csv\", index=False)\n",
+        "\n",
+        "# save the predictions next to the gold labels\n",
+        "predictions_df = pd.DataFrame({'labels': pd.Series(df_test[columnClass]), 'predictions': pd.Series(test_intent_predictions)})\n",
+        "predictions_df.to_csv(path+\"predictions/predictions_\"+report_name+\".csv\")\n",
+        "\n",
+        "# short text summary: run name followed by the weighted averages\n",
+        "content = report_name + \"\\n\" + str(weighted_avg) + \"\\n\"\n",
+        "with open(path+\"reports/report_\"+report_name+\".txt\", 'w') as f:\n",
+        "  f.write(content)\n"
+      ],
+      "metadata": {
+        "id": "RQ0LYGuOt4A4"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "Lbwg2H8sJRe7"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "4mX5g55AJRhj"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "-D8Gj6kzJRjv"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "Lwz5cO2eJRmD"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "yr_UWq14JRoI"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "WJM6J6_EJRqx"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "CtYR7NTvJRs2"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "Qppm6jATJRvM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "rK5nK4gyJRx0"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "vSjWwcQKJRz7"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "G9pjdMdNW_KS"
+      },
+      "source": [
+        "validation_scores = model.predict(word_seq_validation)\n",
+        "predictions = np.argmax(validation_scores, axis=1)"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "IHpVJ79IW_M0",
+        "outputId": "2e1657b3-04d1-42f1-ea8b-9bbcd4744108"
+      },
+      "source": [
+        "# classification_report expects (y_true, y_pred): with the arguments reversed,\n",
+        "# per-class precision and recall are swapped and the supports count predictions.\n",
+        "report = classification_report(y_validation, predictions, output_dict = True)\n",
+        "accuracy = report['accuracy']\n",
+        "weighted_avg = report['weighted avg']\n",
+        "print(accuracy, weighted_avg)"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "0.5726683109527725 {'precision': 0.6118028288513718, 'recall': 0.5726683109527725, 'f1-score': 0.5870482221489528, 'support': 10947}\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n",
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n",
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "9SKjWffUW_PC"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "LpgkGq-fW_RN"
+      },
+      "source": [
+        "df_test = pd.read_csv(test_path, sep=\"\\t\")\n",
+        "# Reuse the existing `encoder` instead of fitting a new LabelEncoder here: ids\n",
+        "# fitted on the test labels only match the model's if every class is present.\n",
+        "y_test = encoder.transform(df_test[columnClass])\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Q9eYqi5SW_Ta",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "31e45f20-583a-4ca6-eac8-21863f6fef5b"
+      },
+      "source": [
+        "# Remove French stop words from every test text, then encode the texts with\n",
+        "# the existing tokenizer (it is not re-fitted here) and pad them to max_len.\n",
+        "raw_docs_test = df_test[columnText].tolist()\n",
+        "\n",
+        "print(\"pre-processing test data...\")\n",
+        "\n",
+        "stop_words = set(stopwords.words('french'))\n",
+        "\n",
+        "\n",
+        "def _without_stop_words(doc):\n",
+        "    # Tokenize `doc` and join the tokens that are not stop words with spaces.\n",
+        "    kept = [token for token in word_tokenize(doc, language='french') if token not in stop_words]\n",
+        "    return \" \".join(kept)\n",
+        "\n",
+        "\n",
+        "processed_docs_test = [_without_stop_words(doc) for doc in tqdm(raw_docs_test)]\n",
+        "\n",
+        "print(\"tokenizing input data...\")\n",
+        "word_seq_test = tokenizer.texts_to_sequences(processed_docs_test)\n",
+        "word_seq_test = sequence.pad_sequences(word_seq_test, maxlen=max_len)"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "pre-processing test data...\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "100%|██████████| 13137/13137 [00:09<00:00, 1331.48it/s]\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "tokenizing input data...\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "_WjpJN-Bqjeb"
+      },
+      "source": [
+        "test_scores = model.predict(word_seq_test)\n",
+        "predictions = np.argmax(test_scores, axis=1)"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "zUwjL_dQqjgx",
+        "outputId": "912642ad-95eb-413a-d074-8d4881a57359"
+      },
+      "source": [
+        "# classification_report expects (y_true, y_pred): with the arguments reversed,\n",
+        "# per-class precision and recall are swapped and the supports count predictions.\n",
+        "report = classification_report(y_test, predictions, output_dict = True)\n",
+        "accuracy = report['accuracy']\n",
+        "weighted_avg = report['weighted avg']\n",
+        "print(accuracy, weighted_avg)"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "0.5698409073608891 {'precision': 0.6081680700148677, 'recall': 0.5698409073608891, 'f1-score': 0.5847417616022411, 'support': 13137}\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n",
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n",
+            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+            "  _warn_prf(average, modifier, msg_start, len(result))\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "ka6DcPe7qqvg",
+        "outputId": "0c8cfbe6-178d-4208-98ba-4ba688e32939"
+      },
+      "source": [
+        "from sklearn.metrics import confusion_matrix\n",
+        "\n",
+        "# Per-class report table: metrics come from `report` (keyed by the class id as a\n",
+        "# string), confusion counts from the confusion matrix of y_test vs predictions.\n",
+        "classesName = encoder.classes_\n",
+        "class_ids = encoder.transform(encoder.classes_)\n",
+        "classes = [str(e) for e in class_ids]\n",
+        "\n",
+        "accuracy = report['accuracy']\n",
+        "weighted_avg = report['weighted avg']\n",
+        "\n",
+        "# A class id that appears in neither argument of classification_report has no\n",
+        "# entry in `report`; use an all-zero row for it instead of raising KeyError.\n",
+        "no_samples = {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0}\n",
+        "per_class = [report.get(c, no_samples) for c in classes]\n",
+        "\n",
+        "# Pin the matrix order to the encoder's class ids so that row k always belongs\n",
+        "# to classesName[k], even when a class is absent from y_test and predictions.\n",
+        "cnf_matrix = confusion_matrix(y_test, predictions, labels=class_ids)\n",
+        "TP = np.diag(cnf_matrix)\n",
+        "FP = cnf_matrix.sum(axis=0) - TP  # column total minus the diagonal\n",
+        "FN = cnf_matrix.sum(axis=1) - TP  # row total minus the diagonal\n",
+        "TN = cnf_matrix.sum() - (FP + FN + TP)\n",
+        "\n",
+        "dff = pd.DataFrame({\n",
+        "    'className': classesName,\n",
+        "    'precision': [row['precision'] for row in per_class],\n",
+        "    'recall': [row['recall'] for row in per_class],\n",
+        "    'f1-score': [row['f1-score'] for row in per_class],\n",
+        "    'support': [row['support'] for row in per_class],\n",
+        "    'FP': FP,\n",
+        "    'FN': FN,\n",
+        "    'TP': TP,\n",
+        "    'TN': TN,\n",
+        "})\n",
+        "\n",
+        "run_name = \"test_cnn_s\"+str(maxOfInstancePerClass)\n",
+        "print(run_name)\n",
+        "print(weighted_avg)\n",
+        "print(accuracy)\n",
+        "print(dff)\n",
+        "\n",
+        "dff.to_csv(\"drive/MyDrive/Classification-EDdA/report_\"+run_name+\".csv\", index=False)"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "test_cnn_s10000\n",
+            "{'precision': 0.6081680700148677, 'recall': 0.5698409073608891, 'f1-score': 0.5847417616022411, 'support': 13137}\n",
+            "0.5698409073608891\n",
+            "                                      className  precision  ...    TP     TN\n",
+            "0               Agriculture - Economie rustique   0.216535  ...    55  12636\n",
+            "1                                      Anatomie   0.459821  ...   103  12768\n",
+            "2                                     Antiquité   0.287975  ...    91  12710\n",
+            "3                                  Architecture   0.339623  ...   108  12722\n",
+            "4                               Arts et métiers   0.015504  ...     2  12995\n",
+            "5                                    Beaux-arts   0.060000  ...     6  13018\n",
+            "6                       Belles-lettres - Poésie   0.127660  ...    30  12761\n",
+            "7                                        Blason   0.228571  ...    24  12993\n",
+            "8                                    Caractères   0.037037  ...     1  13110\n",
+            "9                                        Chasse   0.221311  ...    27  12962\n",
+            "10                                       Chimie   0.160714  ...    18  12991\n",
+            "11                                     Commerce   0.443418  ...   192  12490\n",
+            "12                        Droit - Jurisprudence   0.762879  ...  1081  11263\n",
+            "13                          Economie domestique   0.000000  ...     0  13102\n",
+            "14                                    Grammaire   0.408929  ...   229  12254\n",
+            "15                                   Géographie   0.917312  ...  2607   9910\n",
+            "16                                     Histoire   0.405063  ...   288  11777\n",
+            "17                           Histoire naturelle   0.743292  ...   831  11661\n",
+            "18                                          Jeu   0.061538  ...     4  13067\n",
+            "19                                       Marine   0.590805  ...   257  12549\n",
+            "20                           Maréchage - Manège   0.620690  ...    72  13001\n",
+            "21                                Mathématiques   0.549669  ...    83  12903\n",
+            "22                                       Mesure   0.095238  ...     4  13087\n",
+            "23              Militaire (Art) - Guerre - Arme   0.476351  ...   141  12704\n",
+            "24                                  Minéralogie   0.000000  ...     0  13111\n",
+            "25                                      Monnaie   0.054795  ...     4  13051\n",
+            "26                                      Musique   0.287500  ...    46  12904\n",
+            "27                                    Médailles   0.000000  ...     0  13107\n",
+            "28                         Médecine - Chirurgie   0.376218  ...   193  12149\n",
+            "29                                      Métiers   0.605634  ...   731  11047\n",
+            "30                                    Pharmacie   0.070423  ...     5  13045\n",
+            "31                                  Philosophie   0.071429  ...     8  12996\n",
+            "32  Physique - [Sciences physico-mathématiques]   0.378378  ...   112  12674\n",
+            "33                                    Politique   0.000000  ...     0  13110\n",
+            "34                                        Pêche   0.170213  ...     8  13069\n",
+            "35                                     Religion   0.326371  ...   125  12488\n",
+            "36                                    Spectacle   0.000000  ...     0  13121\n",
+            "37                                 Superstition   0.000000  ...     0  13112\n",
+            "\n",
+            "[38 rows x 9 columns]\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "BqJ1_hUUqqx5"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "bhfuGNwIqrOQ"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "NkL3MopyqrQk"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "XLHl-pvzqjjI"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "lLR_Xvi9qjlo"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "8cGcLOFTqjoP"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "vLGTnit_W_V8"
+      },
+      "source": [
+        ""
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "R-3lBXjDD9wE"
+      },
+      "source": [
+        "def predict(data, max_len):\n",
+        "  \"\"\"Encode `data` with prepare_sequence, score it with the global `model`\n",
+        "  and return the index of the highest-scoring class of every sample.\"\"\"\n",
+        "  padded, _ = prepare_sequence(data, max_len)\n",
+        "  class_scores = model.predict(padded)\n",
+        "  return np.argmax(class_scores, axis=1)\n",
+        "\n",
+        "\n",
+        "def eval(data, labels, max_len):\n",
+        "  \"\"\"Print accuracy and weighted-average metrics of predict(data, max_len) against `labels`.\n",
+        "\n",
+        "  NOTE: the name shadows the builtin eval(); it is kept because existing cells call it.\n",
+        "  \"\"\"\n",
+        "  pred_labels_ = predict(data, max_len)\n",
+        "  # classification_report expects (y_true, y_pred); reversed, precision and recall swap.\n",
+        "  report = classification_report(labels, pred_labels_, output_dict = True)\n",
+        "  print(report['accuracy'], report['weighted avg'])"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "6T3kAvKvExgc",
+        "outputId": "c6d4560e-fc64-4579-9adb-79c2e36d2386"
+      },
+      "source": [
+        "# evaluate on the validation set\n",
+        "eval(df_validation[columnText], y_validation, max_len)"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.7/dist-packages/zeugma/keras_transformers.py:33: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+            "  return np.array(self.texts_to_sequences(texts))\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "0.06925290207361841 {'precision': 0.09108131158125257, 'recall': 0.06925290207361841, 'f1-score': 0.06099084715237025, 'support': 10079}\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "pTDJA03_-8yu",
+        "outputId": "d8bcdf73-c4c3-4c88-b063-90bd1cad5122"
+      },
+      "source": [
+        "# Evaluation on the test set\n",
+        "df_test = pd.read_csv(test_path, sep=\"\\t\")\n",
+        "#df_test = resample_classes(df_test, columnClass, maxOfInstancePerClass)\n",
+        "\n",
+        "# NOTE(review): a fresh LabelEncoder is fitted on the test labels here; its integer codes presumably\n",
+        "# only line up with the model's outputs if the test set has the same classes as training - confirm.\n",
+        "encoder = preprocessing.LabelEncoder()\n",
+        "y_test = encoder.fit_transform(df_test[columnClass])\n",
+        "eval(df_test[columnText], y_test, max_len)\n"
+      ],
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.7/dist-packages/zeugma/keras_transformers.py:33: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+            "  return np.array(self.texts_to_sequences(texts))\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "0.07231483595950369 {'precision': 0.081194635559303, 'recall': 0.07231483595950369, 'f1-score': 0.06322383877903374, 'support': 13137}\n"
+          ]
+        }
+      ]
+    }
+  ]
+}
\ No newline at end of file
-- 
GitLab