diff --git a/notebooks/Classification_BiLSTM.ipynb b/notebooks/Classification_BiLSTM.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b10a1addfab6f000b6ed10161df20ceede21eddb --- /dev/null +++ b/notebooks/Classification_BiLSTM.ipynb @@ -0,0 +1,1893 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "EDdA-Classification_LSTM.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "0yFsoHXX8Iyy" + }, + "source": [ + "# Deep learning for EDdA classification" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tFlUCDL2778i" + }, + "source": [ + "## Setup colab environment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Sp8d_Uus7SHJ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f0dbf6d3-0d66-46b9-c00b-6fd788f4c689" + }, + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jQBu-p6hBU-j" + }, + "source": [ + "### Install packages" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bTIXsF6kBUdh" + }, + "source": [ + "#!pip install zeugma\n", + "#!pip install plot_model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "56-04SNF8BMx" + }, + "source": [ + "### Import librairies" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HwWkSznz7SEv" + }, + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pickle\n", + "import os\n", + "\n", + "from tqdm import tqdm\n", + "import requests, zipfile, io\n", + "import codecs\n", + "\n", + "from sklearn import preprocessing # LabelEncoder\n", + "from sklearn.metrics import classification_report\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "from keras.preprocessing import sequence\n", + "from keras.preprocessing.text import Tokenizer\n", + "\n", + "from keras.layers import BatchNormalization, Input, Reshape, Conv1D, MaxPool1D, Conv2D, MaxPool2D, Concatenate\n", + "from keras.layers import Embedding, Dropout, Flatten, Dense, LSTM, Bidirectional\n", + "from keras.models import Model, load_model\n", + "from keras.callbacks import ModelCheckpoint\n" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xrekV6W978l4" + }, + "source": [ + "### Utils functions" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4LJ5blQR7PUe" + }, + "source": [ + "\n", + "def resample_classes(df, classColumnName, numberOfInstances):\n", + " #random numberOfInstances elements\n", + " replace = False # with replacement\n", + " fn = lambda obj: obj.loc[np.random.choice(obj.index, numberOfInstances if len(obj) > numberOfInstances else len(obj), replace),:]\n", + " return df.groupby(classColumnName, as_index=False).apply(fn)\n", + " \n" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "-Rh3JMDh7zYd" + }, + "source": [ + "" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MtLr35eM753e" + }, + "source": [ + "## Load Data" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FnbNT4NF7zal", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "300e8fb1-43a5-4f25-9afa-15e04730b0f4" + }, + "source": [ + "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n", + "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2022-02-18 07:32:19-- https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n", + "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n", + "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 175634219 (167M) [text/tab-separated-values]\n", + "Saving to: ‘training_set.tsv’\n", + "\n", + "training_set.tsv 100%[===================>] 167.50M 28.1MB/s in 6.6s \n", + "\n", + "2022-02-18 07:32:26 (25.5 MB/s) - ‘training_set.tsv’ saved [175634219/175634219]\n", + "\n", + "--2022-02-18 07:32:26-- https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv\n", + "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n", + "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 42730598 (41M) [text/tab-separated-values]\n", + "Saving to: ‘test_set.tsv’\n", + "\n", + "test_set.tsv 100%[===================>] 40.75M 19.6MB/s in 2.1s \n", + "\n", + "2022-02-18 07:32:29 (19.6 MB/s) - ‘test_set.tsv’ saved [42730598/42730598]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Loading dataset" + ], + "metadata": { + "id": "UHushJ1XfUj9" + } + }, + { + "cell_type": "code", + "source": [ + "train_path = 'training_set.tsv'\n", + "test_path = 'test_set.tsv'" + ], + "metadata": { + "id": "Q4te2c0bfvaJ" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "nRLaQUO97zcq" + }, + "source": [ + "df_train = pd.read_csv(train_path, sep=\"\\t\")\n", + "\n", + "#df_train = resample_classes(df_train, columnClass, maxOfInstancePerClass)\n" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "df_train.sample(5)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 600 + }, + "id": "2MvHEc7zVK1N", + "outputId": "57e4ff59-b74f-49ff-f441-5b4fba0973cf" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + " <div id=\"df-00b9a714-8d6b-461c-83fb-70e844fc260e\">\n", + " <div class=\"colab-df-container\">\n", + " <div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>volume</th>\n", + " <th>numero</th>\n", + " <th>head</th>\n", + " <th>normClass</th>\n", + " <th>classEDdA</th>\n", + " <th>author</th>\n", + " <th>id_enccre</th>\n", + " <th>domaine_enccre</th>\n", + " <th>ensemble_domaine_enccre</th>\n", + " <th>content</th>\n", + " <th>contentWithoutClass</th>\n", + " <th>firstParagraph</th>\n", + " <th>nb_words</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>31223</th>\n", + " <td>1</td>\n", + " <td>1494</td>\n", + " <td>AJUSTOIRE</td>\n", + " <td>La Monnoie</td>\n", + " <td>a la Monnoie.</td>\n", + " <td>unsigned</td>\n", + " <td>v1-1010-0</td>\n", + " <td>monnaie</td>\n", + " <td>Monnaie</td>\n", + " <td>AJUSTOIRE, s. m. (à la Monnoie.) est une balan...</td>\n", + " <td>ajustoire s. m. monnoie balance \\n sert ajuste...</td>\n", + " <td>ajustoire s. m. monnoie balance \\n sert ajuste...</td>\n", + " <td>85</td>\n", + " </tr>\n", + " <tr>\n", + " <th>710</th>\n", + " <td>14</td>\n", + " <td>719</td>\n", + " <td>REPAITRIR</td>\n", + " <td>Grammaire</td>\n", + " <td>Gram.</td>\n", + " <td>unsigned</td>\n", + " <td>v14-375-0</td>\n", + " <td>grammaire</td>\n", + " <td>Grammaire</td>\n", + " <td>REPAITRIR, v. act. (Gram.) paîtrir de-rechef.\\...</td>\n", + " <td>repaitrir vers act paîtrir de-rechef \\n arti...</td>\n", + " <td>repaitrir vers act paîtrir de-rechef \\n arti...</td>\n", + " <td>19</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4668</th>\n", + " <td>4</td>\n", + " <td>1838</td>\n", + " <td>Courir</td>\n", + " <td>Géographie</td>\n", + " <td>en Géographie</td>\n", + " <td>d'Alembert</td>\n", + " <td>v4-834-5</td>\n", + " <td>géographie</td>\n", + " <td>Géographie</td>\n", + " <td>Courir, se dit aussi en Géographie. Cette suit...</td>\n", + " <td>courir suite \\n montagne -on court est-ouest \\...</td>\n", + " <td>courir suite \\n montagne -on court est-ouest \\...</td>\n", + " <td>68</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25846</th>\n", + " <td>1</td>\n", + " <td>2789</td>\n", + " <td>Anneaux</td>\n", + " <td>Manufacture en soie</td>\n", + " <td>manufactures en soie</td>\n", + " <td>unsigned</td>\n", + " <td>v1-2037-10</td>\n", + " <td>manufacture</td>\n", + " <td>Arts et métiers</td>\n", + " <td>Anneaux, s. m. pl. ce sont dans les manufactur...</td>\n", + " <td>anneau s. m. pl manufacture \\n soie petit cerc...</td>\n", + " <td>anneau s. m. pl manufacture \\n soie petit cerc...</td>\n", + " <td>214</td>\n", + " </tr>\n", + " <tr>\n", + " <th>21034</th>\n", + " <td>1</td>\n", + " <td>2118</td>\n", + " <td>AMBLEUR</td>\n", + " <td>Manège</td>\n", + " <td>Manege</td>\n", + " <td>Eidous</td>\n", + " <td>v1-1516-0</td>\n", + " <td>manège</td>\n", + " <td>Maréchage - Manège</td>\n", + " <td>AMBLEUR, s. m. (Manege.) Officier de la grande...</td>\n", + " <td>ambleur s. m. officier grand \\n petit écurie r...</td>\n", + " <td>ambleur s. m. officier grand \\n petit écurie r...</td>\n", + " <td>25</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>\n", + " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-00b9a714-8d6b-461c-83fb-70e844fc260e')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\">\n", + " \n", + " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\">\n", + " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", + " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", + " </svg>\n", + " </button>\n", + " \n", + " <style>\n", + " .colab-df-container {\n", + " display:flex;\n", + " flex-wrap:wrap;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " </style>\n", + "\n", + " <script>\n", + " const buttonEl =\n", + " document.querySelector('#df-00b9a714-8d6b-461c-83fb-70e844fc260e button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-00b9a714-8d6b-461c-83fb-70e844fc260e');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? Visit the ' +\n", + " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " </script>\n", + " </div>\n", + " </div>\n", + " " + ], + "text/plain": [ + " volume ... nb_words\n", + "31223 1 ... 85\n", + "710 14 ... 19\n", + "4668 4 ... 68\n", + "25846 1 ... 214\n", + "21034 1 ... 25\n", + "\n", + "[5 rows x 13 columns]" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Configuration\n" + ], + "metadata": { + "id": "-63bh_cKfN4p" + } + }, + { + "cell_type": "code", + "source": [ + "columnText = 'contentWithoutClass'\n", + "columnClass = 'ensemble_domaine_enccre'\n", + "\n", + "minOfInstancePerClass = 0\n", + "maxOfInstancePerClass = 1500\n", + "\n", + "batch_size = 64\n", + "validation_split = 0.20\n", + "#max_len = 512 # \n", + "max_nb_words = 20000 # taille du vocabulaire\n", + "max_sequence_length = 512 # taille max du 'document' \n", + "epochs = 10\n", + "\n", + "#embedding_name = \"fasttext\" \n", + "#embedding_dim = 300 \n", + "\n", + "embedding_name = \"glove.6B.100d\"\n", + "embedding_dim = 100 \n", + "\n", + "path = \"drive/MyDrive/Classification-EDdA/\"\n", + "encoder_filename = \"label_encoder.pkl\"\n", + "tokenizer_filename = \"tokenizer_keras.pkl\"" + ], + "metadata": { + "id": "nsRuyzYUfOBg" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Preprocessing\n" + ], + "metadata": { + "id": "ZDz-Y1LCfQt0" + } + }, + { + "cell_type": "code", + "source": [ + "if maxOfInstancePerClass != 10000:\n", + " df_train = resample_classes(df_train, columnClass, maxOfInstancePerClass)" + ], + "metadata": { + "id": "vw7UlS6-v0sN" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "vGWAgBH87ze8" + }, + "source": [ + "labels = df_train[columnClass]\n", + "numberOfClasses = labels.nunique()\n", + "\n", + "\n", + "if os.path.isfile(path+encoder_filename): \n", + " # load existing encoder \n", + " with open(path+encoder_filename, 'rb') as file:\n", + " encoder = pickle.load(file)\n", + "\n", + "else:\n", + " encoder = preprocessing.LabelEncoder()\n", + " encoder.fit(labels)\n", + "\n", + " with open(path+encoder_filename, 'wb') as file:\n", + " pickle.dump(encoder, file)\n", + "\n", + "\n", + "labels = encoder.transform(labels)" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "encoder.classes_" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SME4vvVhW9Sn", + "outputId": "00feb4ff-88b0-49b7-97ca-94219632371a" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array(['Agriculture - Economie rustique', 'Anatomie', 'Antiquité',\n", + " 'Architecture', 'Arts et métiers', 'Beaux-arts',\n", + " 'Belles-lettres - Poésie', 'Blason', 'Caractères', 'Chasse',\n", + " 'Chimie', 'Commerce', 'Droit - Jurisprudence',\n", + " 'Economie domestique', 'Grammaire', 'Géographie', 'Histoire',\n", + " 'Histoire naturelle', 'Jeu', 'Marine', 'Maréchage - Manège',\n", + " 'Mathématiques', 'Mesure', 'Militaire (Art) - Guerre - Arme',\n", + " 'Minéralogie', 'Monnaie', 'Musique', 'Médailles',\n", + " 'Médecine - Chirurgie', 'Métiers', 'Pharmacie', 'Philosophie',\n", + " 'Physique - [Sciences physico-mathématiques]', 'Politique',\n", + " 'Pêche', 'Religion', 'Spectacle', 'Superstition'], dtype=object)" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "code", + "source": [ + "labels_index = dict(zip(list(encoder.classes_), encoder.transform(list(encoder.classes_))))" + ], + "metadata": { + "id": "nIzWQ2VbW_UO" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "labels_index" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4e7ggEGiXC_W", + "outputId": "e35c232e-cc24-47e7-aa26-0e5af7a570b4" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'Agriculture - Economie rustique': 0,\n", + " 'Anatomie': 1,\n", + " 'Antiquité': 2,\n", + " 'Architecture': 3,\n", + " 'Arts et métiers': 4,\n", + " 'Beaux-arts': 5,\n", + " 'Belles-lettres - Poésie': 6,\n", + " 'Blason': 7,\n", + " 'Caractères': 8,\n", + " 'Chasse': 9,\n", + " 'Chimie': 10,\n", + " 'Commerce': 11,\n", + " 'Droit - Jurisprudence': 12,\n", + " 'Economie domestique': 13,\n", + " 'Grammaire': 14,\n", + " 'Géographie': 15,\n", + " 'Histoire': 16,\n", + " 'Histoire naturelle': 17,\n", + " 'Jeu': 18,\n", + " 'Marine': 19,\n", + " 'Maréchage - Manège': 20,\n", + " 'Mathématiques': 21,\n", + " 'Mesure': 22,\n", + " 'Militaire (Art) - Guerre - Arme': 23,\n", + " 'Minéralogie': 24,\n", + " 'Monnaie': 25,\n", + " 'Musique': 26,\n", + " 'Médailles': 27,\n", + " 'Médecine - Chirurgie': 28,\n", + " 'Métiers': 29,\n", + " 'Pharmacie': 30,\n", + " 'Philosophie': 31,\n", + " 'Physique - [Sciences physico-mathématiques]': 32,\n", + " 'Politique': 33,\n", + " 'Pêche': 34,\n", + " 'Religion': 35,\n", + " 'Spectacle': 36,\n", + " 'Superstition': 37}" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Loading pre-trained embeddings\n", + "\n", + "#### FastText" + ], + "metadata": { + "id": "pguj5vY4vUnj" + } + }, + { + "cell_type": "code", + "source": [ + "# download FastText (prend trop de place pour le laisser sur le drive)\n", + "zip_file_url = \"https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip\"\n", + "r = requests.get(zip_file_url)\n", + "z = zipfile.ZipFile(io.BytesIO(r.content))\n", + "z.extractall()" + ], + "metadata": { + "id": "p25AssV4vUuj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print('loading word embeddings FastText...')\n", + "\n", + "embeddings_index = {}\n", + "f = codecs.open('crawl-300d-2M.vec', encoding='utf-8')\n", + "\n", + "for line in tqdm(f):\n", + " values = line.rstrip().rsplit(' ')\n", + " word = values[0]\n", + " coefs = np.asarray(values[1:], dtype='float32')\n", + " embeddings_index[word] = coefs\n", + "f.close()\n", + "\n", + "print('found %s word vectors' % len(embeddings_index))" + ], + "metadata": { + "id": "MYAljz1jvUxH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "#### GLOVE" + ], + "metadata": { + "id": "4KdmWdnhvU6A" + } + }, + { + "cell_type": "code", + "source": [ + "# download Glove\n", + "#zip_file_url = \"https://nlp.stanford.edu/data/glove.6B.zip\"\n", + "#r = requests.get(zip_file_url)\n", + "#z = zipfile.ZipFile(io.BytesIO(r.content))\n", + "#z.extractall()" + ], + "metadata": { + "id": "38rp4FSnvVBE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print('loading word embeddings GLOVE...')\n", + "\n", + "embeddings_index = {}\n", + "f = open(path+\"embeddings/\"+embedding_name+\".txt\", encoding='utf-8')\n", + "for line in tqdm(f):\n", + " values = line.split()\n", + " word = values[0]\n", + " coefs = np.asarray(values[1:], dtype='float32')\n", + " embeddings_index[word] = coefs\n", + "f.close()\n", + "\n", + "print('Found %s word vectors.' % len(embeddings_index))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RVvIk8HdvVDZ", + "outputId": "5ca6c955-1764-42e9-c91c-4fab3953312d" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "loading word embeddings GLOVE...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "400000it [00:19, 20293.28it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Found 400000 word vectors.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HuUVfklf-dSR" + }, + "source": [ + "## Training models" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NTNh6kMTp_eU", + "outputId": "19b9945e-36ee-4ee7-ae45-43eb22122b93" + }, + "source": [ + "\n", + "raw_docs_train = df_train[columnText].tolist()\n", + "\n", + "\n", + "print(\"pre-processing train data...\")\n", + "\n", + "if os.path.isfile(path+tokenizer_filename):\n", + " with open(path+tokenizer_filename, 'rb') as file:\n", + " tokenizer = pickle.load(file)\n", + "else:\n", + " tokenizer = Tokenizer(num_words = max_nb_words)\n", + " tokenizer.fit_on_texts(raw_docs_train) \n", + "\n", + " with open(path+tokenizer_filename, 'wb') as file:\n", + " pickle.dump(tokenizer, file)\n", + "\n", + "sequences = tokenizer.texts_to_sequences(raw_docs_train)\n", + "\n", + "word_index = tokenizer.word_index\n", + "print(\"dictionary size: \", len(word_index))\n", + "\n", + "#pad sequences\n", + "data = sequence.pad_sequences(sequences, maxlen=max_sequence_length)\n", + "\n", + "print('Shape of data tensor:', data.shape)\n", + "print('Shape of label tensor:', labels.shape)\n", + "print(labels)" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "pre-processing train data...\n", + "dictionary size: 190508\n", + "Shape of data tensor: (27381, 512)\n", + "Shape of label tensor: (27381,)\n", + "[ 0 0 0 ... 37 37 37]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "data" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FtmB9Vgcaxgl", + "outputId": "4165d4d1-a778-4831-9342-9e48f62763bd" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[ 0, 0, 0, ..., 147, 7021, 7021],\n", + " [ 0, 0, 0, ..., 613, 9556, 861],\n", + " [ 0, 0, 0, ..., 3733, 4758, 29],\n", + " ...,\n", + " [ 0, 0, 0, ..., 980, 3069, 347],\n", + " [ 0, 0, 0, ..., 2341, 231, 15988],\n", + " [ 0, 0, 0, ..., 4626, 27, 22]], dtype=int32)" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# split the data into a training set and a validation set\n", + "\n", + "indices = np.arange(data.shape[0])\n", + "np.random.shuffle(indices)\n", + "data = data[indices]\n", + "labels = labels[indices]\n", + "\n", + "nb_validation_samples = int(validation_split * data.shape[0])\n", + "\n", + "x_train = data[:-nb_validation_samples]\n", + "y_train = labels[:-nb_validation_samples]\n", + "x_val = data[-nb_validation_samples:]\n", + "y_val = labels[-nb_validation_samples:]\n" + ], + "metadata": { + "id": "sHYJ4P-YDfFb" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wGjQI0YgpQAS", + "outputId": "e155402c-f124-4619-fe43-49c4b5865df5" + }, + "source": [ + "#embedding matrix\n", + "\n", + "print('preparing embedding matrix...')\n", + "\n", + "embedding_matrix = np.zeros((len(word_index)+1, embedding_dim))\n", + "\n", + "for word, i in word_index.items():\n", + " embedding_vector = embeddings_index.get(word)\n", + " if embedding_vector is not None : \n", + " embedding_matrix[i] = embedding_vector\n" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "preparing embedding matrix...\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "#filter_sizes = [2, 3, 5]\n", + "drop = 0.2\n", + "\n", + "embedding_layer = Embedding(len(word_index)+1, embedding_dim, input_length = max_sequence_length,\n", + " weights=[embedding_matrix], trainable=False)\n", + "inputs = Input(shape=(max_sequence_length), dtype='int32')\n", + "embedding = embedding_layer(inputs)\n", + "\n", + "print(embedding.shape)\n", + "#reshape = Reshape((max_sequence_length, embedding_dim, 1))(embedding)\n", + "#print(reshape.shape)\n", + "\n", + "\n", + "#model = Sequential()\n", + "#model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, input_length=X.shape[1]))\n", + "#model.add(SpatialDropout1D(0.2))\n", + "#model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))\n", + "#model.add(Dense(13, activation='softmax'))\n", + "#model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", + "\n", + "\n", + "\n", + "#conv_0 = Conv1D(64, 5, activation='relu')(lstm)\n", + "#conv_1 = Conv1D(128, 5, activation='relu')(embedding)\n", + "#conv_2 = Conv1D(128, 5, activation='relu')(embedding)\n", + "\n", + "#maxpool_0 = MaxPool1D(pool_size=4)(conv_0)\n", + "\n", + "lstm = Bidirectional(LSTM(100))(embedding)\n", + "\n", + "#maxpool_1 = MaxPool1D(5)(conv_1)\n", + "#maxpool_2 = MaxPool1D(35)(conv_2)\n", + "\n", + "#concatenated_tensor = Concatenate(axis=1)([maxpool_0, maxpool_1, maxpool_2])\n", + "#flatten = Flatten()(maxpool_0)\n", + "dropout = Dropout(drop)(lstm)\n", + "output = Dense(len(labels_index), activation='softmax')(dropout)\n", + "\n", + "# this creates a model that includes\n", + "model = Model(inputs=inputs, outputs=output)\n", + "\n", + "checkpoint = ModelCheckpoint('weights_lstm_sentece.hdf5', monitor='val_acc', verbose=1, save_best_only=True, mode='auto')\n", + "#adam = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)\n", + "\n", + "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])\n", + "model.summary()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OUphqlYJCC9n", + "outputId": "3fc0fc31-321c-46dd-a7c7-ff046bc5c819" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(None, 512, 100)\n", + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 512)] 0 \n", + " \n", + " embedding (Embedding) (None, 512, 100) 19050900 \n", + " \n", + " bidirectional (Bidirectiona (None, 200) 160800 \n", + " l) \n", + " \n", + " dropout (Dropout) (None, 200) 0 \n", + " \n", + " dense (Dense) (None, 38) 7638 \n", + " \n", + "=================================================================\n", + "Total params: 19,219,338\n", + "Trainable params: 168,438\n", + "Non-trainable params: 19,050,900\n", + "_________________________________________________________________\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "history = model.fit(x_train, y_train, \n", + " batch_size=batch_size, \n", + " epochs=epochs, \n", + " verbose=1,\n", + " callbacks=[checkpoint],\n", + " validation_data=(x_val, y_val))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3aUBdLdNCEGK", + "outputId": "20203bf0-98ae-4a56-86cb-76dd8a36ac9f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/10\n", + "343/343 [==============================] - ETA: 0s - loss: 2.7653 - acc: 0.2694\n", + "Epoch 1: val_acc improved from -inf to 0.38386, saving model to weights_lstm_sentece.hdf5\n", + "343/343 [==============================] - 388s 1s/step - loss: 2.7653 - acc: 0.2694 - val_loss: 2.3262 - val_acc: 0.3839\n", + "Epoch 2/10\n", + "343/343 [==============================] - ETA: 0s - loss: 2.2038 - acc: 0.4024\n", + "Epoch 2: val_acc improved from 0.38386 to 0.43006, saving model to weights_lstm_sentece.hdf5\n", + "343/343 [==============================] - 400s 1s/step - loss: 2.2038 - acc: 0.4024 - val_loss: 2.0739 - val_acc: 0.4301\n", + "Epoch 3/10\n", + "343/343 [==============================] - ETA: 0s - loss: 1.9875 - acc: 0.4516" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "ZcYbQsQEJKLq" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "plt.plot(history.history['acc'])\n", + "plt.plot(history.history['val_acc'])\n", + "plt.title('model accuracy')\n", + "plt.ylabel('accuracy')\n", + "plt.xlabel('epoch')\n", + "plt.legend(['train', 'validation'], loc='lower right')\n", + "plt.show()\n", + "\n", + "# summarize history for loss\n", + "plt.plot(history.history['loss'])\n", + "plt.plot(history.history['val_loss'])\n", + "plt.title('model loss')\n", + "plt.ylabel('loss')\n", + "plt.xlabel('epoch')\n", + "plt.legend(['train', 'validation'], loc='upper right')\n", + "plt.show()\n", + "\n" + ], + "metadata": { + "id": "Job-3uMvJKN_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uw6YR76p_AF0" + }, + "source": [ + "## Saving models" + ] + }, + { + "cell_type": "code", + "source": [ + "name = \"lstm_\"+embedding_name+\"_s\"+str(maxOfInstancePerClass)" + ], + "metadata": { + "id": "irB4mATeeUcw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "ykTp9lyRaAma" + }, + "source": [ + "model.save(path+name+\".h5\")\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5J4xDoqRUSfS" + }, + "source": [ + "# save embeddings\n", + "\n", + "# saving embeddings index \n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HHlEtipG_Cp0" + }, + "source": [ + "## Loading models" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fKt8ft1t_Cxx" + }, + "source": [ + "model = load_model(path+name+\".h5\")\n", + "\n", + "with open(path+tokenizer_filename, 'rb') as file:\n", + " tokenizer = pickle.load(file)\n", + "\n", + "with open(path+encoder_filename, 'rb') as file:\n", + " encoder = pickle.load(file)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zbS4poso-3k7" + }, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "source": [ + "df_test = pd.read_csv(test_path, sep=\"\\t\")\n" + ], + "metadata": { + "id": "KWORvsadvBbr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "test_texts = df_test[columnText].tolist()\n", + "test_labels = df_test[columnClass].tolist()\n", + "\n", + "test_sequences = tokenizer.texts_to_sequences(test_texts)\n", + "test_input = sequence.pad_sequences(test_sequences, maxlen=max_sequence_length)\n", + "\n", + "# Get predictions\n", + "test_predictions_probas = model.predict(test_input)\n", + "test_predictions = test_predictions_probas.argmax(axis=-1)" + ], + "metadata": { + "id": "Xr0o-0i5t38G" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "\n", + "\n", + "test_intent_predictions = encoder.inverse_transform(test_predictions)\n", + "#test_intent_original = encoder.inverse_transform(test_labels)\n", + "\n", + "print('accuracy: ', sum(test_intent_predictions == test_labels) / len(test_labels))\n", + "print(\"Precision, Recall and F1-Score:\\n\\n\", classification_report(test_labels, test_intent_predictions))\n", + "\n" + ], + "metadata": { + "id": "lSn8yZ0gt3-d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "name = \"test_\"+ name\n", + "\n", + "#classesName = encoder.classes_\n", + "#classes = [str(e) for e in encoder.transform(encoder.classes_)]\n", + "\n", + "report = classification_report(test_labels, test_intent_predictions, output_dict = True)\n", + "\n", + "precision = []\n", + "recall = []\n", + "f1 = []\n", + "support = []\n", + "dff = pd.DataFrame(columns= ['className', 'precision', 'recall', 'f1-score', 'support', 'FP', 'FN', 'TP', 'TN'])\n", + "for c in encoder.classes_:\n", + " precision.append(report[c]['precision'])\n", + " recall.append(report[c]['recall'])\n", + " f1.append(report[c]['f1-score'])\n", + " support.append(report[c]['support'])\n", + "\n", + "accuracy = report['accuracy']\n", + "weighted_avg = report['weighted avg']\n", + "\n", + "\n", + "cnf_matrix = confusion_matrix(test_labels, test_intent_predictions)\n", + "FP = cnf_matrix.sum(axis=0) - np.diag(cnf_matrix)\n", + "FN = cnf_matrix.sum(axis=1) - np.diag(cnf_matrix)\n", + "TP = np.diag(cnf_matrix)\n", + "TN = cnf_matrix.sum() - (FP + FN + TP)\n", + "\n", + "dff['className'] = encoder.classes_\n", + "dff['precision'] = precision\n", + "dff['recall'] = recall\n", + "dff['f1-score'] = f1\n", + "dff['support'] = support\n", + "dff['FP'] = FP\n", + "dff['FN'] = FN\n", + "dff['TP'] = TP\n", + "dff['TN'] = TN\n", + "\n", + "\n", + "\n", + " \n", + "content = name + \"\\n\"\n", + "print(name)\n", + "content += str(weighted_avg) + \"\\n\"\n", + "print(weighted_avg)\n", + "print(accuracy)\n", + "print(dff)\n", + "\n", + "dff.to_csv(path+\"reports/report_\"+name+\".csv\", index=False)\n", + "\n", + "# enregistrer les predictions\n", + "pd.DataFrame({'labels': pd.Series(df_test[columnClass]), 'predictions': pd.Series(test_intent_predictions)}).to_csv(path+\"predictions/predictions_\"+name+\".csv\")\n", + "\n", + "with open(path+\"reports/report_\"+name+\".txt\", 'w') as f:\n", + " f.write(content)\n" + ], + "metadata": { + "id": "RQ0LYGuOt4A4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "Lbwg2H8sJRe7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "4mX5g55AJRhj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "-D8Gj6kzJRjv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "Lwz5cO2eJRmD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "yr_UWq14JRoI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "WJM6J6_EJRqx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "CtYR7NTvJRs2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "Qppm6jATJRvM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "rK5nK4gyJRx0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "vSjWwcQKJRz7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "G9pjdMdNW_KS" + }, + "source": [ + "predictions = model.predict(word_seq_validation)\n", + "predictions = np.argmax(predictions,axis=1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IHpVJ79IW_M0", + "outputId": "2e1657b3-04d1-42f1-ea8b-9bbcd4744108" + }, + "source": [ + "report = classification_report(predictions, y_validation, output_dict = True)\n", + "\n", + "accuracy = report['accuracy']\n", + "weighted_avg = report['weighted avg']\n", + "\n", + "print(accuracy, weighted_avg)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.5726683109527725 {'precision': 0.6118028288513718, 'recall': 0.5726683109527725, 'f1-score': 0.5870482221489528, 'support': 10947}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "9SKjWffUW_PC" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LpgkGq-fW_RN" + }, + "source": [ + "df_test = pd.read_csv(test_path, sep=\"\\t\")\n", + "\n", + "encoder = preprocessing.LabelEncoder()\n", + "y_test = encoder.fit_transform(df_test[columnClass])\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Q9eYqi5SW_Ta", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "31e45f20-583a-4ca6-eac8-21863f6fef5b" + }, + "source": [ + "raw_docs_test = df_test[columnText].tolist()\n", + "\n", + "print(\"pre-processing test data...\")\n", + "\n", + "stop_words = set(stopwords.words('french'))\n", + "\n", + "processed_docs_test = []\n", + "for doc in tqdm(raw_docs_test):\n", + " tokens = word_tokenize(doc, language='french')\n", + " filtered = [word for word in tokens if word not in stop_words]\n", + " processed_docs_test.append(\" \".join(filtered))\n", + "#end for\n", + "\n", + "print(\"tokenizing input data...\")\n", + "#tokenizer = Tokenizer(num_words=max_len, lower=True, char_level=False)\n", + "#tokenizer.fit_on_texts(processed_docs_train + processed_docs_validation) #leaky\n", + "word_seq_test = tokenizer.texts_to_sequences(processed_docs_test)\n", + "\n", + "#pad sequences\n", + "word_seq_test = sequence.pad_sequences(word_seq_test, maxlen=max_len)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "pre-processing test data...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 13137/13137 [00:09<00:00, 1331.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tokenizing input data...\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_WjpJN-Bqjeb" + }, + "source": [ + "predictions = model.predict(word_seq_test)\n", + "predictions = np.argmax(predictions,axis=1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zUwjL_dQqjgx", + "outputId": "912642ad-95eb-413a-d074-8d4881a57359" + }, + "source": [ + "report = classification_report(predictions, y_test, output_dict = True)\n", + "\n", + "accuracy = report['accuracy']\n", + "weighted_avg = report['weighted avg']\n", + "\n", + "print(accuracy, weighted_avg)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.5698409073608891 {'precision': 0.6081680700148677, 'recall': 0.5698409073608891, 'f1-score': 0.5847417616022411, 'support': 13137}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ka6DcPe7qqvg", + "outputId": "0c8cfbe6-178d-4208-98ba-4ba688e32939" + }, + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "\n", + "classesName = encoder.classes_\n", + "classes = [str(e) for e in encoder.transform(encoder.classes_)]\n", + "\n", + "precision = []\n", + "recall = []\n", + "f1 = []\n", + "support = []\n", + "dff = pd.DataFrame(columns= ['className', 'precision', 'recall', 'f1-score', 'support', 'FP', 'FN', 'TP', 'TN'])\n", + "for c in classes:\n", + " precision.append(report[c]['precision'])\n", + " recall.append(report[c]['recall'])\n", + " f1.append(report[c]['f1-score'])\n", + " support.append(report[c]['support'])\n", + "\n", + "accuracy = report['accuracy']\n", + "weighted_avg = report['weighted avg']\n", + "\n", + "\n", + "cnf_matrix = confusion_matrix(y_test, predictions)\n", + "FP = cnf_matrix.sum(axis=0) - np.diag(cnf_matrix)\n", + "FN = cnf_matrix.sum(axis=1) - np.diag(cnf_matrix)\n", + "TP = np.diag(cnf_matrix)\n", + "TN = cnf_matrix.sum() - (FP + FN + TP)\n", + "\n", + "dff['className'] = classesName\n", + "dff['precision'] = precision\n", + "dff['recall'] = recall\n", + "dff['f1-score'] = f1\n", + "dff['support'] = support\n", + "dff['FP'] = FP\n", + "dff['FN'] = FN\n", + "dff['TP'] = TP\n", + "dff['TN'] = TN\n", + "\n", + "print(\"test_cnn_s\"+str(maxOfInstancePerClass))\n", + "\n", + "print(weighted_avg)\n", + "print(accuracy)\n", + "print(dff)\n", + "\n", + "dff.to_csv(\"drive/MyDrive/Classification-EDdA/report_test_cnn_s\"+str(maxOfInstancePerClass)+\".csv\", index=False)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "test_cnn_s10000\n", + "{'precision': 0.6081680700148677, 'recall': 0.5698409073608891, 'f1-score': 0.5847417616022411, 'support': 13137}\n", + "0.5698409073608891\n", + " className precision ... TP TN\n", + "0 Agriculture - Economie rustique 0.216535 ... 55 12636\n", + "1 Anatomie 0.459821 ... 103 12768\n", + "2 Antiquité 0.287975 ... 91 12710\n", + "3 Architecture 0.339623 ... 108 12722\n", + "4 Arts et métiers 0.015504 ... 2 12995\n", + "5 Beaux-arts 0.060000 ... 6 13018\n", + "6 Belles-lettres - Poésie 0.127660 ... 30 12761\n", + "7 Blason 0.228571 ... 24 12993\n", + "8 Caractères 0.037037 ... 1 13110\n", + "9 Chasse 0.221311 ... 27 12962\n", + "10 Chimie 0.160714 ... 18 12991\n", + "11 Commerce 0.443418 ... 192 12490\n", + "12 Droit - Jurisprudence 0.762879 ... 1081 11263\n", + "13 Economie domestique 0.000000 ... 0 13102\n", + "14 Grammaire 0.408929 ... 229 12254\n", + "15 Géographie 0.917312 ... 2607 9910\n", + "16 Histoire 0.405063 ... 288 11777\n", + "17 Histoire naturelle 0.743292 ... 831 11661\n", + "18 Jeu 0.061538 ... 4 13067\n", + "19 Marine 0.590805 ... 257 12549\n", + "20 Maréchage - Manège 0.620690 ... 72 13001\n", + "21 Mathématiques 0.549669 ... 83 12903\n", + "22 Mesure 0.095238 ... 4 13087\n", + "23 Militaire (Art) - Guerre - Arme 0.476351 ... 141 12704\n", + "24 Minéralogie 0.000000 ... 0 13111\n", + "25 Monnaie 0.054795 ... 4 13051\n", + "26 Musique 0.287500 ... 46 12904\n", + "27 Médailles 0.000000 ... 0 13107\n", + "28 Médecine - Chirurgie 0.376218 ... 193 12149\n", + "29 Métiers 0.605634 ... 731 11047\n", + "30 Pharmacie 0.070423 ... 5 13045\n", + "31 Philosophie 0.071429 ... 8 12996\n", + "32 Physique - [Sciences physico-mathématiques] 0.378378 ... 112 12674\n", + "33 Politique 0.000000 ... 0 13110\n", + "34 Pêche 0.170213 ... 8 13069\n", + "35 Religion 0.326371 ... 125 12488\n", + "36 Spectacle 0.000000 ... 0 13121\n", + "37 Superstition 0.000000 ... 0 13112\n", + "\n", + "[38 rows x 9 columns]\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "BqJ1_hUUqqx5" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "bhfuGNwIqrOQ" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NkL3MopyqrQk" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "XLHl-pvzqjjI" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lLR_Xvi9qjlo" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "8cGcLOFTqjoP" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "vLGTnit_W_V8" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "R-3lBXjDD9wE" + }, + "source": [ + "def predict(data, max_len):\n", + " \n", + " pad_sequ_test, _ = prepare_sequence(data, max_len)\n", + " pred_labels_ = model.predict(pad_sequ_test)\n", + "\n", + " return np.argmax(pred_labels_,axis=1)\n", + "\n", + "\n", + "def eval(data, labels, max_len):\n", + " \n", + " pred_labels_ = predict(data, max_len)\n", + " report = classification_report(pred_labels_, labels, output_dict = True)\n", + "\n", + " accuracy = report['accuracy']\n", + " weighted_avg = report['weighted avg']\n", + " \n", + " print(accuracy, weighted_avg)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6T3kAvKvExgc", + "outputId": "c6d4560e-fc64-4579-9adb-79c2e36d2386" + }, + "source": [ + "# evaluation sur le jeu de validation\n", + "eval(df_validation[columnText], y_validation, max_len)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/zeugma/keras_transformers.py:33: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return np.array(self.texts_to_sequences(texts))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.06925290207361841 {'precision': 0.09108131158125257, 'recall': 0.06925290207361841, 'f1-score': 0.06099084715237025, 'support': 10079}\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pTDJA03_-8yu", + "outputId": "d8bcdf73-c4c3-4c88-b063-90bd1cad5122" + }, + "source": [ + "# evaluation sur le jeu de test\n", + "df_test = pd.read_csv(test_path, sep=\"\\t\")\n", + "#df_test = resample_classes(df_test, columnClass, maxOfInstancePerClass)\n", + "\n", + "y_test = df_test[columnClass]\n", + "encoder = preprocessing.LabelEncoder()\n", + "y_test = encoder.fit_transform(y_test)\n", + "\n", + "eval(df_test[columnText], y_test, max_len)\n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/zeugma/keras_transformers.py:33: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return np.array(self.texts_to_sequences(texts))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.07231483595950369 {'precision': 0.081194635559303, 'recall': 0.07231483595950369, 'f1-score': 0.06322383877903374, 'support': 13137}\n" + ] + } + ] + } + ] +} \ No newline at end of file