Skip to content
Snippets Groups Projects
Classification_CNN.ipynb 93.7 KiB
Newer Older
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "EDdA-Classification_CNN_Conv1D-EGC.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0yFsoHXX8Iyy"
      },
      "source": [
        "# Deep learning for EDdA classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tFlUCDL2778i"
      },
      "source": [
        "## Setup colab environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sp8d_Uus7SHJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "976ed0dd-7aeb-4f64-e34b-117733abf38c"
      },
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQBu-p6hBU-j"
      },
      "source": [
        "### Install packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bTIXsF6kBUdh"
      },
      "source": [
        "#!pip install zeugma\n",
        "#!pip install plot_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "56-04SNF8BMx"
      },
      "source": [
        "### Import librairies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HwWkSznz7SEv"
      },
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pickle\n",
        "import os\n",
        "\n",
        "from tqdm import tqdm\n",
        "import requests, zipfile, io\n",
        "import codecs\n",
        "\n",
        "from sklearn import preprocessing # LabelEncoder\n",
        "from sklearn.metrics import classification_report\n",
        "from sklearn.metrics import confusion_matrix\n",
        "\n",
        "from keras.preprocessing import sequence\n",
        "from keras.preprocessing.text import Tokenizer\n",
        "\n",
        "from keras.layers import BatchNormalization, Input, Reshape, Conv1D, MaxPool1D, Conv2D, MaxPool2D, Concatenate\n",
        "from keras.layers import Embedding, Dropout, Flatten, Dense\n",
        "from keras.models import Model, load_model\n",
        "from keras.callbacks import ModelCheckpoint\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xrekV6W978l4"
      },
      "source": [
        "### Utils functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4LJ5blQR7PUe"
      },
      "source": [
        "\n",
        "def resample_classes(df, classColumnName, numberOfInstances):\n",
        "  #random numberOfInstances elements\n",
        "  replace = False  # with replacement\n",
        "  fn = lambda obj: obj.loc[np.random.choice(obj.index, numberOfInstances if len(obj) > numberOfInstances else len(obj), replace),:]\n",
        "  return df.groupby(classColumnName, as_index=False).apply(fn)\n",
        "    \n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MtLr35eM753e"
      },
      "source": [
        "## Load Data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FnbNT4NF7zal",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c2a72d94-c7ae-4e6a-b962-ec4677053555"
      },
      "source": [
        "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
        "!wget https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-02-17 19:08:55--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/training_set.tsv\n",
            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 175634219 (167M) [text/tab-separated-values]\n",
            "Saving to: ‘training_set.tsv’\n",
            "\n",
            "training_set.tsv    100%[===================>] 167.50M  28.2MB/s    in 6.5s    \n",
            "\n",
            "2022-02-17 19:09:02 (25.7 MB/s) - ‘training_set.tsv’ saved [175634219/175634219]\n",
            "\n",
            "--2022-02-17 19:09:02--  https://projet.liris.cnrs.fr/geode/EDdA-Classification/datasets/test_set.tsv\n",
            "Resolving projet.liris.cnrs.fr (projet.liris.cnrs.fr)... 134.214.142.28\n",
            "Connecting to projet.liris.cnrs.fr (projet.liris.cnrs.fr)|134.214.142.28|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 42730598 (41M) [text/tab-separated-values]\n",
            "Saving to: ‘test_set.tsv’\n",
            "\n",
            "test_set.tsv        100%[===================>]  40.75M  19.7MB/s    in 2.1s    \n",
            "\n",
            "2022-02-17 19:09:05 (19.7 MB/s) - ‘test_set.tsv’ saved [42730598/42730598]\n",
            "\n"
          ]
        }
      ]
Loading
Loading full blame...