{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "EDdA-Classification_CNN_Conv1D-EGC.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0yFsoHXX8Iyy"
      },
      "source": [
        "# Deep learning for EDdA classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tFlUCDL2778i"
      },
      "source": [
        "## Set up the Colab environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sp8d_Uus7SHJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "976ed0dd-7aeb-4f64-e34b-117733abf38c"
      },
      "source": [
        "# Mount Google Drive at /content/drive; cached artifacts (label encoder,\n",
        "# tokenizer, GloVe files) are read from / written to it through `path`.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQBu-p6hBU-j"
      },
      "source": [
        "### Install packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bTIXsF6kBUdh"
      },
      "source": [
        "# Optional extras: neither package is imported by the import cell of this\n",
        "# notebook. Uncomment to install (prefer %pip, which targets the kernel).\n",
        "#!pip install zeugma\n",
        "#!pip install plot_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "56-04SNF8BMx"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HwWkSznz7SEv"
      },
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pickle\n",
        "import os\n",
        "\n",
        "from tqdm import tqdm\n",
        "import requests, zipfile, io\n",
        "import codecs\n",
        "\n",
        "from sklearn import preprocessing # LabelEncoder\n",
        "from sklearn.metrics import classification_report\n",
        "from sklearn.metrics import confusion_matrix\n",
        "\n",
        "from keras.preprocessing import sequence\n",
        "from keras.preprocessing.text import Tokenizer\n",
        "\n",
        "from keras.layers import BatchNormalization, Input, Reshape, Conv1D, MaxPool1D, Conv2D, MaxPool2D, Concatenate\n",
        "from keras.layers import Embedding, Dropout, Flatten, Dense\n",
        "from keras.models import Model, load_model\n",
        "from keras.callbacks import ModelCheckpoint\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xrekV6W978l4"
      },
      "source": [
        "### Utility functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4LJ5blQR7PUe"
      },
      "source": [
        "\n",
        "def resample_classes(df, classColumnName, numberOfInstances, random_state=None):\n",
        "  \"\"\"Down-sample every class of `df` to at most `numberOfInstances` rows.\n",
        "\n",
        "  Classes with fewer rows are kept whole. Rows are drawn WITHOUT replacement.\n",
        "\n",
        "  Args:\n",
        "    df: DataFrame to resample.\n",
        "    classColumnName: name of the column holding the class label.\n",
        "    numberOfInstances: maximum number of rows kept per class.\n",
        "    random_state: optional int seed for reproducible sampling; None keeps the\n",
        "      previous behaviour of drawing from NumPy's global random state.\n",
        "\n",
        "  Returns:\n",
        "    The per-class samples combined by DataFrame.groupby(...).apply().\n",
        "  \"\"\"\n",
        "  rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
        "\n",
        "  def fn(obj):\n",
        "    # never ask for more rows than the class has (sampling is without replacement)\n",
        "    n = min(len(obj), numberOfInstances)\n",
        "    return obj.loc[rng.choice(obj.index, n, replace=False), :]\n",
        "\n",
        "  return df.groupby(classColumnName, as_index=False).apply(fn)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MtLr35eM753e"
      },
      "source": [
        "## Load Data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FnbNT4NF7zal",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c2a72d94-c7ae-4e6a-b962-ec4677053555"
      },
      "source": [
        "# -nc (no-clobber): skip files that already exist. Without it, re-running the\n",
        "# cell downloads ~200 MB again and saves duplicates (training_set.tsv.1, ...).\n",
        "!wget -nc https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
        "!wget -nc https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-02-17 19:08:55--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 175634219 (167M) [text/tab-separated-values]\n",
            "Saving to: ‘training_set.tsv’\n",
            "\n",
            "training_set.tsv    100%[===================>] 167.50M  28.2MB/s    in 6.5s    \n",
            "\n",
            "2022-02-17 19:09:02 (25.7 MB/s) - ‘training_set.tsv’ saved [175634219/175634219]\n",
            "\n",
            "--2022-02-17 19:09:02--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv\n",
            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 42730598 (41M) [text/tab-separated-values]\n",
            "Saving to: ‘test_set.tsv’\n",
            "\n",
            "test_set.tsv        100%[===================>]  40.75M  19.7MB/s    in 2.1s    \n",
            "\n",
            "2022-02-17 19:09:05 (19.7 MB/s) - ‘test_set.tsv’ saved [42730598/42730598]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading dataset"
      ],
      "metadata": {
        "id": "UHushJ1XfUj9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Local copies of the datasets fetched by the download cell\n",
        "train_path, test_path = 'training_set.tsv', 'test_set.tsv'"
      ],
      "metadata": {
        "id": "Q4te2c0bfvaJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nRLaQUO97zcq"
      },
      "source": [
        "# The training set is tab-separated (.tsv)\n",
        "df_train = pd.read_csv(train_path, delimiter=\"\\t\")\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Peek at a few random training rows\n",
        "df_train.sample(n=5)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 548
        },
        "id": "2MvHEc7zVK1N",
        "outputId": "a934bbce-4ebe-4d5a-eb4e-18db1ce10532"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-24c25d4d-881e-4c52-8629-da71863f5656\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>volume</th>\n",
              "      <th>numero</th>\n",
              "      <th>head</th>\n",
              "      <th>normClass</th>\n",
              "      <th>classEDdA</th>\n",
              "      <th>author</th>\n",
              "      <th>id_enccre</th>\n",
              "      <th>domaine_enccre</th>\n",
              "      <th>ensemble_domaine_enccre</th>\n",
              "      <th>content</th>\n",
              "      <th>contentWithoutClass</th>\n",
              "      <th>firstParagraph</th>\n",
              "      <th>nb_words</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>17253</th>\n",
              "      <td>1</td>\n",
              "      <td>4938</td>\n",
              "      <td>Auge</td>\n",
              "      <td>Verrerie</td>\n",
              "      <td>dans les Verreries</td>\n",
              "      <td>unsigned</td>\n",
              "      <td>v1-3613-9</td>\n",
              "      <td>verrerie</td>\n",
              "      <td>Métiers</td>\n",
              "      <td>Auge, dans les Verreries, ce sont de gros hêtr...</td>\n",
              "      <td>auge gros hêtre \\n creusés tient plein eau ser...</td>\n",
              "      <td>auge gros hêtre \\n creusés tient plein eau ser...</td>\n",
              "      <td>70</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16170</th>\n",
              "      <td>9</td>\n",
              "      <td>2017</td>\n",
              "      <td>Lettres de reprise</td>\n",
              "      <td>unclassified</td>\n",
              "      <td>unclassified</td>\n",
              "      <td>Boucher d'Argis</td>\n",
              "      <td>v9-1324-120</td>\n",
              "      <td>jurisprudence</td>\n",
              "      <td>Droit - Jurisprudence</td>\n",
              "      <td>Lettres de reprise, sont une commission que\\nl...</td>\n",
              "      <td>lettre reprise commission \\n prend chancelleri...</td>\n",
              "      <td>lettre reprise commission \\n prend chancelleri...</td>\n",
              "      <td>36</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>12041</th>\n",
              "      <td>16</td>\n",
              "      <td>1213</td>\n",
              "      <td>THOLUS</td>\n",
              "      <td>Architecture romaine</td>\n",
              "      <td>Archit. rom.</td>\n",
              "      <td>Jaucourt</td>\n",
              "      <td>v16-709-0</td>\n",
              "      <td>architecture</td>\n",
              "      <td>Architecture</td>\n",
              "      <td>THOLUS, s. m. (Archit. rom.) Vitruve nomme\\nth...</td>\n",
              "      <td>tholus s. m.   vitruve nomme \\n tholus coupe d...</td>\n",
              "      <td>tholus s. m.   vitruve nomme \\n tholus coupe d...</td>\n",
              "      <td>95</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>20046</th>\n",
              "      <td>9</td>\n",
              "      <td>4327</td>\n",
              "      <td>MALLIENS, les</td>\n",
              "      <td>Géographie ancienne</td>\n",
              "      <td>Geog. anc.</td>\n",
              "      <td>Jaucourt</td>\n",
              "      <td>v9-2589-0</td>\n",
              "      <td>géographie</td>\n",
              "      <td>Géographie</td>\n",
              "      <td>MALLIENS, les, (Géog. anc.) en latin Malli ;\\n...</td>\n",
              "      <td>malliens géog anc latin malli \\n ancien peuple...</td>\n",
              "      <td>malliens géog anc latin malli \\n ancien peuple...</td>\n",
              "      <td>71</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>35783</th>\n",
              "      <td>9</td>\n",
              "      <td>3883</td>\n",
              "      <td>MAGDOLOS</td>\n",
              "      <td>Géographie ancienne</td>\n",
              "      <td>Geog. anc.</td>\n",
              "      <td>Jaucourt</td>\n",
              "      <td>v9-2377-0</td>\n",
              "      <td>géographie</td>\n",
              "      <td>Géographie</td>\n",
              "      <td>MAGDOLOS, (Géog. anc.) ville d'Egypte dont\\npa...</td>\n",
              "      <td>magdolos géog anc ville egypte \\n parlent jéré...</td>\n",
              "      <td>magdolos géog anc ville egypte \\n parlent jéré...</td>\n",
              "      <td>50</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-24c25d4d-881e-4c52-8629-da71863f5656')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-24c25d4d-881e-4c52-8629-da71863f5656 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-24c25d4d-881e-4c52-8629-da71863f5656');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "       volume  ...  nb_words\n",
              "17253       1  ...        70\n",
              "16170       9  ...        36\n",
              "12041      16  ...        95\n",
              "20046       9  ...        71\n",
              "35783       9  ...        50\n",
              "\n",
              "[5 rows x 13 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Configuration\n"
      ],
      "metadata": {
        "id": "-63bh_cKfN4p"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "columnText = 'contentWithoutClass'\n",
        "columnClass = 'ensemble_domaine_enccre'\n",
        "\n",
        "# Per-class cap for resample_classes; 10000 (the default) disables resampling\n",
        "maxOfInstancePerClass = 10000\n",
        "\n",
        "batch_size = 64\n",
        "validation_split = 0.20\n",
        "max_nb_words = 20000        # vocabulary size kept by the Tokenizer\n",
        "max_sequence_length = 512   # max document length, in tokens\n",
        "epochs = 10\n",
        "\n",
        "# Seed NumPy's global RNG so class resampling and the train/validation\n",
        "# shuffle are reproducible from one run to the next.\n",
        "seed = 42\n",
        "np.random.seed(seed)\n",
        "\n",
        "#embedding_name = \"fasttext\" \n",
        "#embedding_dim = 300 \n",
        "\n",
        "embedding_name = \"glove.6B.100d\"\n",
        "embedding_dim = 100 \n",
        "\n",
        "# Folder on the mounted Google Drive holding cached artifacts and embeddings\n",
        "path = \"drive/MyDrive/Classification-EDdA/\"\n",
        "encoder_filename = \"label_encoder.pkl\"\n",
        "tokenizer_filename = \"tokenizer_keras.pkl\""
      ],
      "metadata": {
        "id": "nsRuyzYUfOBg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Preprocessing\n"
      ],
      "metadata": {
        "id": "ZDz-Y1LCfQt0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# 10000 (the default) means no resampling; any other value caps each class\n",
        "df_train = (resample_classes(df_train, columnClass, maxOfInstancePerClass)\n",
        "            if maxOfInstancePerClass != 10000 else df_train)"
      ],
      "metadata": {
        "id": "cP5e7DvRvwxh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vGWAgBH87ze8"
      },
      "source": [
        "labels  = df_train[columnClass]\n",
        "numberOfClasses = labels.nunique()\n",
        "encoder_path = path + encoder_filename\n",
        "\n",
        "if os.path.isfile(encoder_path):\n",
        "    # Reuse the encoder cached on Drive so label ids stay stable across runs.\n",
        "    # NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
        "    with open(encoder_path, 'rb') as file:\n",
        "        encoder = pickle.load(file)\n",
        "    # A cache fitted on other data makes transform() fail with an obscure\n",
        "    # error; report which classes are missing and how to recover instead.\n",
        "    unknown = sorted(set(labels.unique()) - set(encoder.classes_))\n",
        "    if unknown:\n",
        "        raise ValueError('cached label encoder %s does not know classes %s; '\n",
        "                         'delete it to refit' % (encoder_path, unknown))\n",
        "else:\n",
        "    encoder = preprocessing.LabelEncoder()\n",
        "    encoder.fit(labels)\n",
        "    with open(encoder_path, 'wb') as file:\n",
        "        pickle.dump(encoder, file)\n",
        "\n",
        "labels = encoder.transform(labels)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Class names in label-id order: a class's id is its position in this array\n",
        "encoder.classes_"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SME4vvVhW9Sn",
        "outputId": "8b577b93-69a5-47e3-ea3b-6a2e1d144914"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array(['Agriculture - Economie rustique', 'Anatomie', 'Antiquité',\n",
              "       'Architecture', 'Arts et métiers', 'Beaux-arts',\n",
              "       'Belles-lettres - Poésie', 'Blason', 'Caractères', 'Chasse',\n",
              "       'Chimie', 'Commerce', 'Droit - Jurisprudence',\n",
              "       'Economie domestique', 'Grammaire', 'Géographie', 'Histoire',\n",
              "       'Histoire naturelle', 'Jeu', 'Marine', 'Maréchage - Manège',\n",
              "       'Mathématiques', 'Mesure', 'Militaire (Art) - Guerre - Arme',\n",
              "       'Minéralogie', 'Monnaie', 'Musique', 'Médailles',\n",
              "       'Médecine - Chirurgie', 'Métiers', 'Pharmacie', 'Philosophie',\n",
              "       'Physique - [Sciences physico-mathématiques]', 'Politique',\n",
              "       'Pêche', 'Religion', 'Spectacle', 'Superstition'], dtype=object)"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# class name -> integer id assigned by the label encoder\n",
        "class_names = list(encoder.classes_)\n",
        "labels_index = dict(zip(class_names, encoder.transform(class_names)))"
      ],
      "metadata": {
        "id": "nIzWQ2VbW_UO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Mapping class name -> integer id used as the training target\n",
        "labels_index"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4e7ggEGiXC_W",
        "outputId": "8a22e814-c63c-4e0d-adb9-899c11a1dc9b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'Agriculture - Economie rustique': 0,\n",
              " 'Anatomie': 1,\n",
              " 'Antiquité': 2,\n",
              " 'Architecture': 3,\n",
              " 'Arts et métiers': 4,\n",
              " 'Beaux-arts': 5,\n",
              " 'Belles-lettres - Poésie': 6,\n",
              " 'Blason': 7,\n",
              " 'Caractères': 8,\n",
              " 'Chasse': 9,\n",
              " 'Chimie': 10,\n",
              " 'Commerce': 11,\n",
              " 'Droit - Jurisprudence': 12,\n",
              " 'Economie domestique': 13,\n",
              " 'Grammaire': 14,\n",
              " 'Géographie': 15,\n",
              " 'Histoire': 16,\n",
              " 'Histoire naturelle': 17,\n",
              " 'Jeu': 18,\n",
              " 'Marine': 19,\n",
              " 'Maréchage - Manège': 20,\n",
              " 'Mathématiques': 21,\n",
              " 'Mesure': 22,\n",
              " 'Militaire (Art) - Guerre - Arme': 23,\n",
              " 'Minéralogie': 24,\n",
              " 'Monnaie': 25,\n",
              " 'Musique': 26,\n",
              " 'Médailles': 27,\n",
              " 'Médecine - Chirurgie': 28,\n",
              " 'Métiers': 29,\n",
              " 'Pharmacie': 30,\n",
              " 'Philosophie': 31,\n",
              " 'Physique - [Sciences physico-mathématiques]': 32,\n",
              " 'Politique': 33,\n",
              " 'Pêche': 34,\n",
              " 'Religion': 35,\n",
              " 'Spectacle': 36,\n",
              " 'Superstition': 37}"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading pre-trained embeddings"
      ],
      "metadata": {
        "id": "CoQSwXfNEOOx"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### FastText"
      ],
      "metadata": {
        "id": "PUn9G_LYEYvM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# FastText vectors are only needed when embedding_name == \"fasttext\". They are\n",
        "# downloaded into the runtime (too large to keep on Drive), and only once.\n",
        "fasttext_file = 'crawl-300d-2M.vec'\n",
        "if embedding_name == \"fasttext\" and not os.path.isfile(fasttext_file):\n",
        "    zip_file_url = \"https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip\"\n",
        "    r = requests.get(zip_file_url)\n",
        "    r.raise_for_status()  # fail loudly rather than with BadZipFile on an error page\n",
        "    with zipfile.ZipFile(io.BytesIO(r.content)) as z:\n",
        "        z.extractall()"
      ],
      "metadata": {
        "id": "qKQFmUs0EY5E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the FastText vectors only when FastText is the configured embedding;\n",
        "# otherwise the GloVe loading cell fills embeddings_index.\n",
        "if embedding_name == \"fasttext\":\n",
        "    print('loading word embeddings FastText...')\n",
        "\n",
        "    embeddings_index = {}\n",
        "    with open(fasttext_file, encoding='utf-8') as f:\n",
        "        # the first line of a .vec file is the '<n_words> <dim>' header, not a\n",
        "        # vector; keeping it would store a bogus 1-dimensional entry\n",
        "        next(f)\n",
        "        for line in tqdm(f):\n",
        "            values = line.rstrip().rsplit(' ')\n",
        "            word = values[0]\n",
        "            coefs = np.asarray(values[1:], dtype='float32')\n",
        "            embeddings_index[word] = coefs\n",
        "\n",
        "    print('found %s word vectors' % len(embeddings_index))"
      ],
      "metadata": {
        "id": "x0C0XwhwEY7Z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### GloVe"
      ],
      "metadata": {
        "id": "q81MEgYrEWsj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# download Glove\n",
        "# Optional: extractall() writes into the current working directory, whereas\n",
        "# the GloVe loading cell reads path + \"embeddings/\" -- move the files there.\n",
        "#zip_file_url = \"https://nlp.stanford.edu/data/glove.6B.zip\"\n",
        "#r = requests.get(zip_file_url)\n",
        "#z = zipfile.ZipFile(io.BytesIO(r.content))\n",
        "#z.extractall()"
      ],
      "metadata": {
        "id": "qTHuXs2EEV-M"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# GloVe vectors are read from Drive (embeddings/<embedding_name>.txt). The cell\n",
        "# is skipped when FastText is configured so it neither crashes on a missing\n",
        "# file nor overwrites the FastText embeddings_index.\n",
        "if embedding_name != \"fasttext\":\n",
        "    print('loading word embeddings GLOVE...')\n",
        "\n",
        "    embeddings_index = {}\n",
        "    with open(path+\"embeddings/\"+embedding_name+\".txt\", encoding='utf-8') as f:\n",
        "        for line in tqdm(f):\n",
        "            values = line.split()\n",
        "            word = values[0]\n",
        "            coefs = np.asarray(values[1:], dtype='float32')\n",
        "            embeddings_index[word] = coefs\n",
        "\n",
        "    print('Found %s word vectors.' % len(embeddings_index))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "U0hwICJzEOVL",
        "outputId": "be478f5b-6c4a-48ee-e05d-49b2e7fdd285"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading word embeddings GLOVE...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "400000it [00:12, 31570.08it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found 400000 word vectors.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HuUVfklf-dSR"
      },
      "source": [
        "## Training models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NTNh6kMTp_eU",
        "outputId": "fba63a44-73c5-4928-ad7e-7356b17af073"
      },
      "source": [
        "raw_docs_train = df_train[columnText].tolist()\n",
        "print(\"pre-processing train data...\")\n",
        "\n",
        "# The tokenizer is cached on Drive: reuse it when present so word ids stay\n",
        "# stable between runs, otherwise fit it on the training texts and save it.\n",
        "tokenizer_path = path + tokenizer_filename\n",
        "if not os.path.isfile(tokenizer_path):\n",
        "    tokenizer = Tokenizer(num_words=max_nb_words)\n",
        "    tokenizer.fit_on_texts(raw_docs_train)\n",
        "    with open(tokenizer_path, 'wb') as file:\n",
        "        pickle.dump(tokenizer, file)\n",
        "else:\n",
        "    # NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
        "    with open(tokenizer_path, 'rb') as file:\n",
        "        tokenizer = pickle.load(file)\n",
        "\n",
        "sequences = tokenizer.texts_to_sequences(raw_docs_train)\n",
        "\n",
        "# word_index lists every word seen during fitting; num_words only limits the\n",
        "# ids emitted by texts_to_sequences\n",
        "word_index = tokenizer.word_index\n",
        "print(\"dictionary size: \", len(word_index))\n",
        "\n",
        "# Pad / truncate every sequence to max_sequence_length token ids\n",
        "data = sequence.pad_sequences(sequences, maxlen=max_sequence_length)\n",
        "\n",
        "print('Shape of data tensor:', data.shape)\n",
        "print('Shape of label tensor:', labels.shape)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "pre-processing train data...\n",
            "dictionary size:  190508\n",
            "Shape of data tensor: (46807, 512)\n",
            "Shape of label tensor: (46807,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# split the data into a training set and a validation set\n",
        "\n",
        "# Shuffle samples and labels with the same permutation so pairs stay aligned\n",
        "indices = np.arange(data.shape[0])\n",
        "np.random.shuffle(indices)\n",
        "data = data[indices]\n",
        "labels = labels[indices]\n",
        "\n",
        "nb_validation_samples = int(validation_split * data.shape[0])\n",
        "# Slice with an explicit training size: data[:-0] is EMPTY, so negative\n",
        "# indexing breaks when nb_validation_samples is 0 (e.g. validation_split=0).\n",
        "nb_train_samples = data.shape[0] - nb_validation_samples\n",
        "\n",
        "x_train = data[:nb_train_samples]\n",
        "y_train = labels[:nb_train_samples]\n",
        "x_val = data[nb_train_samples:]\n",
        "y_val = labels[nb_train_samples:]\n"
      ],
      "metadata": {
        "id": "sHYJ4P-YDfFb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wGjQI0YgpQAS",
        "outputId": "ab38f7a8-8c10-471f-88e6-10a62ae1776c"
      },
      "source": [
        "print('preparing embedding matrix...')\n",
        "\n",
        "# Row i holds the pre-trained vector of the word whose tokenizer id is i.\n",
        "# Row 0 (the padding id) and words without a pre-trained vector stay zero.\n",
        "embedding_matrix = np.zeros((len(word_index) + 1, embedding_dim))\n",
        "for word, idx in word_index.items():\n",
        "    if word in embeddings_index:\n",
        "        embedding_matrix[idx] = embeddings_index[word]\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "preparing embedding matrix...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Frozen embedding layer initialised with the pre-trained vectors\n",
        "embedding_layer = Embedding(len(word_index)+1, embedding_dim, input_length = max_sequence_length,\n",
        "                    weights=[embedding_matrix], trainable=False)\n",
        "# shape must be a tuple -- (max_sequence_length) is just an int in parentheses\n",
        "inputs = Input(shape=(max_sequence_length,), dtype='int32')\n",
        "embedding = embedding_layer(inputs)\n",
        "\n",
        "print(embedding.shape)\n",
        "\n",
        "# Architecture tested by Khaled: a single Conv1D followed by a max-pool that\n",
        "# spans its whole output, i.e. a global max over the sequence.\n",
        "kernel_size = 5\n",
        "conv_0 = Conv1D(64, kernel_size, activation='relu')(embedding)\n",
        "# a 'valid' convolution yields max_sequence_length - kernel_size + 1 steps\n",
        "maxpool_0 = MaxPool1D(pool_size=(max_sequence_length - kernel_size + 1))(conv_0)\n",
        "\n",
        "flatten = Flatten()(maxpool_0)\n",
        "output = Dense(len(labels_index), activation='softmax')(flatten)\n",
        "\n",
        "# Build the model from the integer token input to the softmax output\n",
        "model = Model(inputs=inputs, outputs=output)\n",
        "\n",
        "checkpoint = ModelCheckpoint('weights_cnn_sentece.hdf5', monitor='val_acc', verbose=1, save_best_only=True, mode='auto')\n",
        "\n",
        "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])\n",
        "model.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OUphqlYJCC9n",
        "outputId": "7b8a479d-ba34-4ea7-b714-107b9c7166b3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "(None, 512, 100)\n",
            "Model: \"model\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " input_1 (InputLayer)        [(None, 512)]             0         \n",
            "                                                                 \n",
            " embedding (Embedding)       (None, 512, 100)          19050900  \n",
            "                                                                 \n",
            " conv1d (Conv1D)             (None, 508, 64)           32064     \n",
            "                                                                 \n",
            " max_pooling1d (MaxPooling1D  (None, 1, 64)            0         \n",
            " )                                                               \n",
            "                                                                 \n",
            " flatten (Flatten)           (None, 64)                0         \n",
            "                                                                 \n",
            " dense (Dense)               (None, 38)                2470      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 19,085,434\n",
            "Trainable params: 34,534\n",
            "Non-trainable params: 19,050,900\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Train; `checkpoint` saves the model each time val_acc improves.\n",
        "history = model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    validation_data=(x_val, y_val),\n",
        "    batch_size=batch_size,\n",
        "    epochs=epochs,\n",
        "    callbacks=[checkpoint],\n",
        "    verbose=1,\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3aUBdLdNCEGK",
        "outputId": "591e58dc-c403-445a-dab6-3d902f0f6465"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 2.0486 - acc: 0.4831\n",
            "Epoch 1: val_acc improved from -inf to 0.56832, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 93s 157ms/step - loss: 2.0484 - acc: 0.4831 - val_loss: 1.6547 - val_acc: 0.5683\n",
            "Epoch 2/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 1.4558 - acc: 0.6155\n",
            "Epoch 2: val_acc improved from 0.56832 to 0.61949, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 85s 145ms/step - loss: 1.4557 - acc: 0.6156 - val_loss: 1.4356 - val_acc: 0.6195\n",
            "Epoch 3/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 1.2437 - acc: 0.6631\n",
            "Epoch 3: val_acc improved from 0.61949 to 0.63829, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 84s 143ms/step - loss: 1.2439 - acc: 0.6631 - val_loss: 1.3358 - val_acc: 0.6383\n",
            "Epoch 4/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 1.1175 - acc: 0.6942\n",
            "Epoch 4: val_acc improved from 0.63829 to 0.65111, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 84s 143ms/step - loss: 1.1176 - acc: 0.6941 - val_loss: 1.2895 - val_acc: 0.6511\n",
            "Epoch 5/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 1.0243 - acc: 0.7172\n",
            "Epoch 5: val_acc improved from 0.65111 to 0.65356, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 84s 143ms/step - loss: 1.0242 - acc: 0.7172 - val_loss: 1.2751 - val_acc: 0.6536\n",
            "Epoch 6/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 0.9492 - acc: 0.7371\n",
            "Epoch 6: val_acc improved from 0.65356 to 0.65987, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 92s 158ms/step - loss: 0.9491 - acc: 0.7371 - val_loss: 1.2598 - val_acc: 0.6599\n",
            "Epoch 7/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 0.8892 - acc: 0.7536\n",
            "Epoch 7: val_acc did not improve from 0.65987\n",
            "586/586 [==============================] - 86s 147ms/step - loss: 0.8892 - acc: 0.7536 - val_loss: 1.2598 - val_acc: 0.6557\n",
            "Epoch 8/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 0.8387 - acc: 0.7659\n",
            "Epoch 8: val_acc improved from 0.65987 to 0.66179, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 85s 145ms/step - loss: 0.8387 - acc: 0.7659 - val_loss: 1.2452 - val_acc: 0.6618\n",
            "Epoch 9/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 0.7950 - acc: 0.7780\n",
            "Epoch 9: val_acc improved from 0.66179 to 0.66275, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 83s 142ms/step - loss: 0.7950 - acc: 0.7780 - val_loss: 1.2593 - val_acc: 0.6627\n",
            "Epoch 10/10\n",
            "585/586 [============================>.] - ETA: 0s - loss: 0.7575 - acc: 0.7873\n",
            "Epoch 10: val_acc improved from 0.66275 to 0.66286, saving model to weights_cnn_sentece.hdf5\n",
            "586/586 [==============================] - 83s 141ms/step - loss: 0.7575 - acc: 0.7873 - val_loss: 1.2646 - val_acc: 0.6629\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "ZcYbQsQEJKLq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Learning curves, train vs. validation: one figure for accuracy, one for loss.\n",
        "for metric, title, ylabel, legend_loc in [('acc', 'model accuracy', 'accuracy', 'lower right'),\n",
        "                                          ('loss', 'model loss', 'loss', 'upper right')]:\n",
        "    fig, ax = plt.subplots()\n",
        "    ax.plot(history.history[metric])\n",
        "    ax.plot(history.history['val_' + metric])\n",
        "    ax.set(title=title, xlabel='epoch', ylabel=ylabel)\n",
        "    ax.legend(['train', 'validation'], loc=legend_loc)\n",
        "    plt.show()\n"
      ],
      "metadata": {
        "id": "Job-3uMvJKN_",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 573
        },
        "outputId": "612f25b7-e9ad-43c6-e084-041e2a5f3b64"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uw6YR76p_AF0"
      },
      "source": [
        "## Saving models"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Base name shared by the saved model and the report/prediction files.\n",
        "name = \"\".join([\"cnn_conv1D_egc_\", embedding_name, \"_s\", str(maxOfInstancePerClass)])"
      ],
      "metadata": {
        "id": "irB4mATeeUcw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ykTp9lyRaAma"
      },
      "source": [
        "# ModelCheckpoint(save_best_only=True) kept the best-val_acc epoch on disk, but `model`\n",
        "# still holds the last epoch's weights; reload the checkpoint so the best model is saved.\n",
        "model = load_model(checkpoint.filepath)\n",
        "model.save(path+name+\".h5\")\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5J4xDoqRUSfS"
      },
      "source": [
        "# TODO: save the embedding matrix / embeddings index here.\n",
        "# This cell is currently a placeholder and does nothing.\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHlEtipG_Cp0"
      },
      "source": [
        "## Loading models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fKt8ft1t_Cxx"
      },
      "source": [
        "# Reload the saved model together with the pickled tokenizer and label encoder.\n",
        "# NOTE(review): pickle.load can run arbitrary code - only load files produced by this project.\n",
        "def _load_pickle(filename):\n",
        "    \"\"\"Unpickle and return the object stored at path + filename.\"\"\"\n",
        "    with open(path + filename, 'rb') as handle:\n",
        "        return pickle.load(handle)\n",
        "\n",
        "model = load_model(path + name + \".h5\")\n",
        "tokenizer = _load_pickle(tokenizer_filename)\n",
        "encoder = _load_pickle(encoder_filename)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zbS4poso-3k7"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Held-out test set (tab-separated).\n",
        "df_test = pd.read_csv(test_path, delimiter=\"\\t\")\n"
      ],
      "metadata": {
        "id": "KWORvsadvBbr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Encode the test texts with `tokenizer` and pad them to max_sequence_length.\n",
        "test_texts = df_test[columnText].tolist()\n",
        "test_sequences = tokenizer.texts_to_sequences(test_texts)\n",
        "test_input = sequence.pad_sequences(test_sequences, maxlen=max_sequence_length)\n",
        "test_labels = df_test[columnClass].tolist()\n",
        "\n",
        "# Class probabilities per row, then the index of the most probable class.\n",
        "test_predictions_probas = model.predict(test_input)\n",
        "test_predictions = np.argmax(test_predictions_probas, axis=-1)"
      ],
      "metadata": {
        "id": "Xr0o-0i5t38G"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Map predicted class indices back to label names and score them against the gold labels.\n",
        "test_intent_predictions = encoder.inverse_transform(test_predictions)\n",
        "\n",
        "print('accuracy: ', np.mean(test_intent_predictions == test_labels))\n",
        "# zero_division=0: classes that are never predicted get precision/F1 = 0 explicitly,\n",
        "# instead of the same 0 accompanied by an UndefinedMetricWarning.\n",
        "print(\"Precision, Recall and F1-Score:\\n\\n\", classification_report(test_labels, test_intent_predictions, zero_division=0))\n"
      ],
      "metadata": {
        "id": "lSn8yZ0gt3-d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b8077e36-fa15-467d-9bf8-0877982c6f80"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "accuracy:  0.6617672192787558\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Precision, Recall and F1-Score:\n",
            "\n",
            "                                              precision    recall  f1-score   support\n",
            "\n",
            "            Agriculture - Economie rustique       0.40      0.28      0.33       233\n",
            "                                   Anatomie       0.79      0.53      0.63       215\n",
            "                                  Antiquité       0.47      0.50      0.48       272\n",
            "                               Architecture       0.54      0.47      0.51       278\n",
            "                            Arts et métiers       0.37      0.13      0.20       112\n",
            "                                 Beaux-arts       0.49      0.30      0.37        86\n",
            "                    Belles-lettres - Poésie       0.31      0.22      0.26       206\n",
            "                                     Blason       0.59      0.48      0.53       108\n",
            "                                 Caractères       1.00      0.09      0.16        23\n",
            "                                     Chasse       0.66      0.53      0.58       116\n",
            "                                     Chimie       0.40      0.31      0.35       104\n",
            "                                   Commerce       0.48      0.54      0.51       376\n",
            "                      Droit - Jurisprudence       0.79      0.76      0.77      1284\n",
            "                        Economie domestique       0.20      0.04      0.06        27\n",
            "                                  Grammaire       0.43      0.38      0.40       452\n",
            "                                 Géographie       0.95      0.94      0.95      2621\n",
            "                                   Histoire       0.39      0.53      0.45       616\n",
            "                         Histoire naturelle       0.77      0.82      0.79       963\n",
            "                                        Jeu       0.63      0.64      0.64        56\n",
            "                                     Marine       0.65      0.71      0.68       415\n",
            "                         Maréchage - Manège       0.85      0.72      0.78       105\n",
            "                              Mathématiques       0.56      0.61      0.58       140\n",
            "                                     Mesure       0.33      0.05      0.09        37\n",
            "            Militaire (Art) - Guerre - Arme       0.57      0.63      0.60       258\n",
            "                                Minéralogie       0.10      0.05      0.06        22\n",
            "                                    Monnaie       0.27      0.13      0.17        63\n",
            "                                    Musique       0.73      0.53      0.61       137\n",
            "                                  Médailles       0.86      0.26      0.40        23\n",
            "                       Médecine - Chirurgie       0.51      0.64      0.57       455\n",
            "                                    Métiers       0.53      0.67      0.59      1051\n",
            "                                  Pharmacie       0.39      0.14      0.20        65\n",
            "                                Philosophie       0.41      0.27      0.32        94\n",
            "Physique - [Sciences physico-mathématiques]       0.52      0.56      0.54       265\n",
            "                                  Politique       0.50      0.04      0.08        23\n",
            "                                      Pêche       0.75      0.43      0.55        42\n",
            "                                   Religion       0.54      0.55      0.55       328\n",
            "                                  Spectacle       0.00      0.00      0.00         9\n",
            "                               Superstition       0.86      0.27      0.41        22\n",
            "\n",
            "                                   accuracy                           0.66     11702\n",
            "                                  macro avg       0.54      0.41      0.44     11702\n",
            "                               weighted avg       0.66      0.66      0.66     11702\n",
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-class test report, saved under reports/ and predictions/.\n",
        "# A separate variable keeps this cell idempotent: reassigning `name` made every\n",
        "# re-run prepend another \"test_\".\n",
        "report_name = \"test_\" + name\n",
        "\n",
        "# Pin the label order to encoder.classes_ so the report entries and the confusion-matrix\n",
        "# rows line up with className, even if a class is absent from both labels and predictions.\n",
        "class_labels = list(encoder.classes_)\n",
        "report = classification_report(test_labels, test_intent_predictions, labels=class_labels,\n",
        "                               output_dict=True, zero_division=0)\n",
        "weighted_avg = report['weighted avg']\n",
        "# With explicit `labels`, classification_report may emit 'micro avg' instead of 'accuracy'.\n",
        "if 'accuracy' in report:\n",
        "  accuracy = report['accuracy']\n",
        "else:\n",
        "  accuracy = np.mean(test_intent_predictions == test_labels)\n",
        "\n",
        "# One-vs-rest counts per class, read off the confusion matrix.\n",
        "cnf_matrix = confusion_matrix(test_labels, test_intent_predictions, labels=class_labels)\n",
        "TP = np.diag(cnf_matrix)\n",
        "FP = cnf_matrix.sum(axis=0) - TP\n",
        "FN = cnf_matrix.sum(axis=1) - TP\n",
        "TN = cnf_matrix.sum() - (FP + FN + TP)\n",
        "\n",
        "dff = pd.DataFrame({\n",
        "  'className': class_labels,\n",
        "  'precision': [report[c]['precision'] for c in class_labels],\n",
        "  'recall': [report[c]['recall'] for c in class_labels],\n",
        "  'f1-score': [report[c]['f1-score'] for c in class_labels],\n",
        "  'support': [report[c]['support'] for c in class_labels],\n",
        "  'FP': FP,\n",
        "  'FN': FN,\n",
        "  'TP': TP,\n",
        "  'TN': TN,\n",
        "})\n",
        "\n",
        "content = report_name + \"\\n\"\n",
        "content += str(weighted_avg) + \"\\n\"\n",
        "print(report_name)\n",
        "print(weighted_avg)\n",
        "print(accuracy)\n",
        "print(dff)\n",
        "\n",
        "dff.to_csv(path + \"reports/report_\" + report_name + \".csv\", index=False)\n",
        "\n",
        "# Save the predictions next to the gold labels.\n",
        "pd.DataFrame({'labels': pd.Series(df_test[columnClass]),\n",
        "              'predictions': pd.Series(test_intent_predictions)}).to_csv(path + \"predictions/predictions_\" + report_name + \".csv\")\n",
        "\n",
        "with open(path + \"reports/report_\" + report_name + \".txt\", 'w') as f:\n",
        "  f.write(content)\n"
      ],
      "metadata": {
        "id": "RQ0LYGuOt4A4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bf597b6f-9c83-4fe3-c913-eed91c5610d6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "test_cnn_conv1D_egc_glove.6B.100d_s10000\n",
            "{'precision': 0.6623969961145181, 'recall': 0.6617672192787558, 'f1-score': 0.6552681646811063, 'support': 11702}\n",
            "0.6617672192787558\n",
            "                                      className  precision  ...    TP     TN\n",
            "0               Agriculture - Economie rustique   0.395210  ...    66  11368\n",
            "1                                      Anatomie   0.786207  ...   114  11456\n",
            "2                                     Antiquité   0.470383  ...   135  11278\n",
            "3                                  Architecture   0.543210  ...   132  11313\n",
            "4                               Arts et métiers   0.365854  ...    15  11564\n",
            "5                                    Beaux-arts   0.490566  ...    26  11589\n",
            "6                       Belles-lettres - Poésie   0.312500  ...    45  11397\n",
            "7                                        Blason   0.590909  ...    52  11558\n",
            "8                                    Caractères   1.000000  ...     2  11679\n",
            "9                                        Chasse   0.655914  ...    61  11554\n",
            "10                                       Chimie   0.400000  ...    32  11550\n",
            "11                                     Commerce   0.475410  ...   203  11102\n",
            "12                        Droit - Jurisprudence   0.785084  ...   979  10150\n",
            "13                          Economie domestique   0.200000  ...     1  11671\n",
            "14                                    Grammaire   0.434010  ...   171  11027\n",
            "15                                   Géographie   0.947469  ...  2471   8944\n",
            "16                                     Histoire   0.389087  ...   328  10571\n",
            "17                           Histoire naturelle   0.774162  ...   785  10510\n",
            "18                                          Jeu   0.631579  ...    36  11625\n",
            "19                                       Marine   0.647702  ...   296  11126\n",
            "20                           Maréchage - Manège   0.853933  ...    76  11584\n",
            "21                                Mathématiques   0.562914  ...    85  11496\n",
            "22                                       Mesure   0.333333  ...     2  11661\n",
            "23              Militaire (Art) - Guerre - Arme   0.569930  ...   163  11321\n",
            "24                                  Minéralogie   0.100000  ...     1  11671\n",
            "25                                      Monnaie   0.266667  ...     8  11617\n",
            "26                                      Musique   0.734694  ...    72  11539\n",
            "27                                    Médailles   0.857143  ...     6  11678\n",
            "28                         Médecine - Chirurgie   0.513228  ...   291  10971\n",
            "29                                      Métiers   0.527820  ...   702  10023\n",
            "30                                    Pharmacie   0.391304  ...     9  11623\n",
            "31                                  Philosophie   0.409836  ...    25  11572\n",
            "32  Physique - [Sciences physico-mathématiques]   0.522807  ...   149  11301\n",
            "33                                    Politique   0.500000  ...     1  11678\n",
            "34                                        Pêche   0.750000  ...    18  11654\n",
            "35                                     Religion   0.543807  ...   180  11223\n",
            "36                                    Spectacle   0.000000  ...     0  11693\n",
            "37                                 Superstition   0.857143  ...     6  11679\n",
            "\n",
            "[38 rows x 9 columns]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "Lbwg2H8sJRe7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "4mX5g55AJRhj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "-D8Gj6kzJRjv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "Lwz5cO2eJRmD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "yr_UWq14JRoI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "WJM6J6_EJRqx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "CtYR7NTvJRs2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "Qppm6jATJRvM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "rK5nK4gyJRx0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "vSjWwcQKJRz7"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}