{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4YCMlsNwOWs0"
      },
      "source": [
        "# BERT fine-tuning for EDdA classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6xdYI9moOQSv"
      },
      "source": [
        "## Set up the Colab environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WF0qFN_g3ekz",
        "outputId": "f3a5f049-24ee-418f-fe5e-84c633234ad8"
      },
      "outputs": [],
      "source": [
        "# Report the runtime's total RAM. psutil's `total` is the installed physical memory,\n",
        "# not the free amount; 20 GB is the threshold used to recognise a high-RAM runtime.\n",
        "from psutil import virtual_memory\n",
        "\n",
        "total_bytes = virtual_memory().total\n",
        "ram_gb = total_bytes / 1e9\n",
        "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n",
        "\n",
        "high_ram = ram_gb >= 20\n",
        "print('You are using a high-RAM runtime!' if high_ram else 'Not using a high-RAM runtime')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vL0S-s9Uofvn",
        "outputId": "4b7efa4d-7f09-4c8e-bc98-99e6099ede32"
      },
      "outputs": [],
      "source": [
        "# Google Drive is only reachable on Colab. Elsewhere `google.colab` is not installed,\n",
        "# so skip the mount instead of raising and keep \"Run All\" working on a local machine.\n",
        "try:\n",
        "    from google.colab import drive\n",
        "except ImportError:\n",
        "    print('Not running on Colab: skipping Google Drive mount.')\n",
        "else:\n",
        "    drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8hzEGHl7gmzk"
      },
      "source": [
        "## Setup GPU"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dPOU-Efhf4ui",
        "outputId": "121dd21e-f98c-483d-d6d1-2838f732a4e2"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Choose the device later cells move tensors to: CUDA first, then Apple's\n",
        "# MPS backend (macOS), otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device_kind = \"cuda\"\n",
        "elif torch.backends.mps.is_available() and torch.backends.mps.is_built():\n",
        "    device_kind = \"mps\"\n",
        "else:\n",
        "    device_kind = \"cpu\"\n",
        "device = torch.device(device_kind)\n",
        "\n",
        "if device_kind == \"cuda\":\n",
        "    print('There are %d GPU(s) available.' % torch.cuda.device_count())\n",
        "    print('We will use the GPU:', torch.cuda.get_device_name(0))\n",
        "elif device_kind == \"mps\":\n",
        "    print('We will use the GPU')\n",
        "else:\n",
        "    print('No GPU available, using the CPU instead.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jr-S9yYIgGkA"
      },
      "source": [
        "## Install packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pwmZ5bBvgGNh",
        "outputId": "fce0a8bf-1779-4079-c7ac-200ebb2678c5"
      },
      "outputs": [],
      "source": [
        "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
        "# transformers stays pinned: later cells rely on its 4.x API (e.g. `from transformers import AdamW`).\n",
        "# sentencepiece is required by CamembertTokenizer. TODO: pin its version as well.\n",
        "%pip install -q transformers==4.10.3 sentencepiece"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wSqbrupGMc1M"
      },
      "source": [
        "## Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SkErnwgMMbRj"
      },
      "outputs": [],
      "source": [
        "import pandas as pd \n",
        "import numpy as np\n",
        "import csv\n",
        "import os\n",
        "import pickle\n",
        "from sklearn import preprocessing\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import *\n",
        "\n",
        "from transformers import BertTokenizer, CamembertTokenizer, BertForSequenceClassification, AdamW, BertConfig, CamembertForSequenceClassification\n",
        "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "\n",
        "import time\n",
        "import datetime\n",
        "\n",
        "import random\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import plot_confusion_matrix\n",
        "from sklearn.metrics import confusion_matrix\n",
        "from sklearn.metrics import classification_report\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "12SA-qPFgsVo"
      },
      "source": [
        "## Utility functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WkIVcabUgxIl"
      },
      "outputs": [],
      "source": [
        "def resample_classes(df, classColumnName, numberOfInstances, random_state=None):\n",
        "  '''\n",
        "  Keep at most `numberOfInstances` randomly chosen rows of every class.\n",
        "\n",
        "  Rows are drawn *without* replacement; a class with fewer rows keeps all of\n",
        "  them (shuffled). By default numpy's global RNG is used, so the subsample is\n",
        "  only reproducible if `np.random` was seeded beforehand; pass `random_state`\n",
        "  (an int seed) to make it reproducible on its own.\n",
        "\n",
        "  Returns the result of `df.groupby(classColumnName, as_index=False).apply(...)`.\n",
        "  '''\n",
        "  rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
        "\n",
        "  def sample_class(group):\n",
        "    n = min(numberOfInstances, len(group))\n",
        "    return group.loc[rng.choice(group.index, n, replace=False), :]\n",
        "\n",
        "  return df.groupby(classColumnName, as_index=False).apply(sample_class)\n",
        "\n",
        "\n",
        "def flat_accuracy(preds, labels):\n",
        "  '''\n",
        "  Fraction of examples whose arg-max prediction equals the true label.\n",
        "\n",
        "  `preds` holds one row of per-class scores per example (arg-max over axis 1);\n",
        "  `labels` holds the integer class ids and is flattened before comparison.\n",
        "  '''\n",
        "  pred_flat = np.argmax(preds, axis=1).flatten()\n",
        "  labels_flat = labels.flatten()\n",
        "  return np.sum(pred_flat == labels_flat) / len(labels_flat)\n",
        "\n",
        "\n",
        "def format_time(elapsed):\n",
        "  '''\n",
        "  Takes a time in seconds and returns a string hh:mm:ss\n",
        "  (rounded to the nearest second).\n",
        "  '''\n",
        "  elapsed_rounded = int(round(elapsed))\n",
        "  return str(datetime.timedelta(seconds=elapsed_rounded))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c5QKcXulhNJ-"
      },
      "source": [
        "## Load Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jdCdUVOTZrqh",
        "outputId": "ac52be55-ed4b-4c50-dc8c-9124ca6b71e5"
      },
      "outputs": [],
      "source": [
        "# Download into ../data, the directory the dataset paths below read from.\n",
        "# -nc skips files that are already there instead of saving duplicates (training_set.tsv.1, ...).\n",
        "!wget -nc -P ../data https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set.tsv\n",
        "!wget -nc -P ../data https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set.tsv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Download into ../data, the directory the dataset paths below read from.\n",
        "# -nc skips files that are already there instead of saving duplicates (*.tsv.1, ...).\n",
        "!wget -nc -P ../data https://geode.liris.cnrs.fr/EDdA-Classification/datasets/training_set_superdomains.tsv\n",
        "!wget -nc -P ../data https://geode.liris.cnrs.fr/EDdA-Classification/datasets/test_set_superdomains.tsv"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9d1IxD_bLEvp"
      },
      "source": [
        "### Loading dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7JEnKknRoClH"
      },
      "outputs": [],
      "source": [
        "# Dataset variant: '' selects training_set.tsv / test_set.tsv, '_superdomains' the\n",
        "# super-domain files. Keep it consistent with `columnClass` in the Configuration section.\n",
        "dataset_suffix = '_superdomains'\n",
        "data_dir = '../data/'\n",
        "\n",
        "train_path = data_dir + 'training_set' + dataset_suffix + '.tsv'\n",
        "test_path = data_dir + 'test_set' + dataset_suffix + '.tsv'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 496
        },
        "id": "5u1acjunhoxe",
        "outputId": "3038048d-6506-473d-85c9-2d3ebdcc6a72"
      },
      "outputs": [],
      "source": [
        "# The EDdA files are tab-separated; read_table's default separator is '\\t'.\n",
        "df_train = pd.read_table(train_path)\n",
        "df_train.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zj3JDoJNfx1f",
        "outputId": "f1ec1fcf-b287-460a-8110-dbb00091c203"
      },
      "outputs": [],
      "source": [
        "# Shape as (rows, columns), shown as the cell's rich output.\n",
        "df_train.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dADYGtTcn4AB"
      },
      "source": [
        "## Configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I0OrfFsBn4Io"
      },
      "outputs": [],
      "source": [
        "# Text column fed to the model and label column to predict.\n",
        "columnText = 'contentWithoutClass'\n",
        "#columnClass = 'ensemble_domaine_enccre'  # presumably pairs with training_set.tsv / test_set.tsv\n",
        "columnClass = 'super_domain'  # pairs with the *_superdomains.tsv files\n",
        "\n",
        "# Per-class cap for resample_classes. 10000 is a sentinel meaning \"keep every row\":\n",
        "# the Preprocessing cell skips resampling entirely for that value.\n",
        "maxOfInstancePerClass = 10000\n",
        "\n",
        "model_chosen = \"bert\"\n",
        "#model_chosen = \"camembert\"\n",
        "\n",
        "batch_size = 16  # 16 or 32 recommended\n",
        "max_len = 512    # token ids per example after truncation / zero-padding\n",
        "\n",
        "# Output directory for the label encoder and the fine-tuned model.\n",
        "#path = \"drive/MyDrive/Classification-EDdA/\"\n",
        "path = \"../models/new/\"\n",
        "encoder_filename = \"label_encoder.pkl\"\n",
        "\n",
        "# The label encoder is pickled into `path` with a plain open(), which fails if the\n",
        "# directory is missing (e.g. on a fresh checkout), so create it up front.\n",
        "os.makedirs(path, exist_ok=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t3brU-Yvn4XS"
      },
      "source": [
        "## Preprocessing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aQCLJE4Jtm7v"
      },
      "outputs": [],
      "source": [
        "# 10000 is the \"no cap\" sentinel set in the Configuration section; any other value\n",
        "# subsamples every class down to at most maxOfInstancePerClass rows.\n",
        "apply_class_cap = maxOfInstancePerClass != 10000\n",
        "if apply_class_cap:\n",
        "  df_train = resample_classes(df_train, columnClass, maxOfInstancePerClass)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zrjZvs2dhzAy"
      },
      "outputs": [],
      "source": [
        "labels = df_train[columnClass]\n",
        "encoder_path = path + encoder_filename\n",
        "\n",
        "if os.path.isfile(encoder_path):\n",
        "  # Reuse the encoder of a previous run so class ids stay stable between runs.\n",
        "  # NOTE: pickle.load can execute arbitrary code; only load files this project wrote.\n",
        "  with open(encoder_path, 'rb') as file:\n",
        "    encoder = pickle.load(file)\n",
        "\n",
        "  # A stale encoder (fitted on another column/dataset) would otherwise fail later\n",
        "  # with sklearn's less explicit \"unseen labels\" error.\n",
        "  unknown_classes = set(labels.unique()) - set(encoder.classes_)\n",
        "  if unknown_classes:\n",
        "    raise ValueError('{} does not know the classes {}; delete it or change encoder_filename.'\n",
        "                     .format(encoder_path, sorted(map(str, unknown_classes))))\n",
        "else:\n",
        "  encoder = preprocessing.LabelEncoder()\n",
        "  encoder.fit(labels)\n",
        "  with open(encoder_path, 'wb') as file:\n",
        "    pickle.dump(encoder, file)\n",
        "\n",
        "# One classifier output per encoded class id. A reused encoder may know more\n",
        "# classes than this training set contains, and every id it can emit must be a\n",
        "# valid output index.\n",
        "numberOfClasses = len(encoder.classes_)\n",
        "\n",
        "labels = encoder.transform(labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xt_PhH_6h1_3"
      },
      "outputs": [],
      "source": [
        "# Integer class ids as plain python ints, and the raw article texts, for training.\n",
        "labels_train = labels.tolist()\n",
        "sentences_train = df_train[columnText].values\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Dq_KF5WAsbpC",
        "outputId": "7925ce5a-4b9e-4147-fdc1-f2916d0e6600"
      },
      "outputs": [],
      "source": [
        "# Preview only the first articles; echoing the whole array floods the notebook.\n",
        "sentences_train[:5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gs4Agx_5h43M"
      },
      "source": [
        "# Model\n",
        "## Tokenisation & Input Formatting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YZ5PhEYZiCEA"
      },
      "outputs": [],
      "source": [
        "# Hugging Face checkpoint for each supported model; the same name is used to load\n",
        "# both the tokenizer and the classifier.\n",
        "pretrained_checkpoints = {\n",
        "  \"bert\": \"bert-base-multilingual-cased\",\n",
        "  \"camembert\": \"camembert-base\",\n",
        "}\n",
        "if model_chosen not in pretrained_checkpoints:\n",
        "  raise ValueError('model_chosen must be one of {}, got {!r}'.format(sorted(pretrained_checkpoints), model_chosen))\n",
        "\n",
        "tokeniser_bert = pretrained_checkpoints[model_chosen]\n",
        "model_bert = pretrained_checkpoints[model_chosen]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 131,
          "referenced_widgets": [
            "274e505b5f354efc8de3ef26cc43e617",
            "f1f9d5b32f60473b86ae6b340d6c0850",
            "ad5e0e1439a94578a31b80c90dbf3247",
            "0779c8ea0ed24e64a800ae5dff6bc8ce",
            "7870340ac12b469c8ac19de3a47b6e67",
            "5f321455342348f49879a9ca8b392077",
            "9420a47a2bf44ead8cff283f20566cda",
            "99b785ea53744868b8b11e5e94936fcc",
            "8d24b669a39b4876ac0a014dff678db1",
            "2cf386a8d14d43389374f79bfa922675",
            "2c44d9c11e704b70aa32904a23d1790c",
            "0279837673b446b09aac18346213eb7e",
            "09b5f0bbd5c14bc289b0f92a22bb29ab",
            "69004a5069094f8c9d59d5136f627bef",
            "e96a95317b0945c58c8ff0e944c7593e",
            "68b69c9d3a274900bc2892848f725cb0",
            "76007b17ffd2478fa4a86f959d4f1766",
            "cb447c62ce1d4c1ea760175ae619fbb9",
            "d4ad1a78750d49feaea584a82940bb7d",
            "a9c47cb226ee41e18812f29f690992eb",
            "c4c1675163bd4997bb44d7ea3967a356",
            "5032547e748f45a3b0cdd12fafe1dd05",
            "8f467553598f4dcc9abf55da79c11018",
            "9d7a8b3ecfe74f66b4238fe085c05906",
            "58b4f9e0366f4d4eba7f902af84b8965",
            "9383e09698ae4bd1820a4bca22e78315",
            "a189838c4de648198b0f4fc99c29ced8",
            "c4d981755d1d42b6940396b77bc251bc",
            "12afa6b6474b401f9ff3f189cc0d3d11",
            "5978954f56fb40928b970f32d1634aaf",
            "fe0e3b1df104484c98fbdcd31a04e427",
            "2d1d632da0f740c38512c9ad779d3173",
            "df95c20399dd4918bc7559a90886d4aa"
          ]
        },
        "id": "C4bigx_3ibuN",
        "outputId": "ebcca5ee-85d8-4525-c9ad-9fc3b5c1741d"
      },
      "outputs": [],
      "source": [
        "# Load the tokenizer matching the chosen checkpoint.\n",
        "if model_chosen == \"bert\":\n",
        "  print('Loading BERT tokenizer...')\n",
        "  tokenizer = BertTokenizer.from_pretrained(tokeniser_bert)\n",
        "elif model_chosen == \"camembert\":\n",
        "  print('Loading CamemBERT tokenizer...')\n",
        "  tokenizer = CamembertTokenizer.from_pretrained(tokeniser_bert)\n",
        "else:\n",
        "  # Fail here instead of silently keeping an undefined or stale `tokenizer`.\n",
        "  raise ValueError('Unknown model_chosen: {!r}'.format(model_chosen))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5hNod5X9jDZN",
        "outputId": "bca0db0e-7463-40cd-8052-1712965c7a95"
      },
      "outputs": [],
      "source": [
        "# Tokenize every article and map its tokens to vocabulary ids. add_special_tokens=True\n",
        "# adds the model's start/end markers ([CLS] ... [SEP] for BERT). Nothing is truncated\n",
        "# here: sequences are cut / padded to max_len in a later cell.\n",
        "input_ids_train = [\n",
        "    tokenizer.encode(str(sent), add_special_tokens=True)\n",
        "    for sent in sentences_train\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W9EWv5JvjGH3",
        "outputId": "dde87708-7bcb-47c7-af71-2ec2b2e0c2db"
      },
      "outputs": [],
      "source": [
        "# Longest tokenized article, in ids, before truncation to max_len.\n",
        "print('Max sentence length train: ', max(map(len, input_ids_train)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xh1TQJyvjOx5"
      },
      "outputs": [],
      "source": [
        "# Truncate or right-pad every id sequence to exactly max_len, padding with id 0.\n",
        "# NOTE: truncation keeps the first max_len ids, so long articles lose their final\n",
        "# [SEP]; evaluate_bert truncates the same way, so keep both in sync if this changes.\n",
        "# NOTE(review): 0 is BERT's pad id; CamemBERT's tokenizer uses another pad id (check\n",
        "# tokenizer.pad_token_id), but the `token_id > 0` attention mask hides these positions.\n",
        "padded_train = []\n",
        "for ids in input_ids_train:\n",
        "  # list() also accepts the array rows left by a previous run of this cell, so\n",
        "  # re-running it no longer fails on `array + list` broadcasting.\n",
        "  ids = list(ids)[:max_len]\n",
        "  padded_train.append(ids + [0] * (max_len - len(ids)))\n",
        "\n",
        "padded_train = input_ids_train = np.array(padded_train)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZiwY6gn0jUkD"
      },
      "outputs": [],
      "source": [
        "# Attention mask: 1 where the token id is > 0 (real token), 0 on the zero padding.\n",
        "# Stored as a list (one row per sequence) of lists of python ints.\n",
        "attention_masks_train = (np.asarray(padded_train) > 0).astype(int).tolist()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oBTR5AfAjXJe"
      },
      "outputs": [],
      "source": [
        "# NOTE: no validation split is made; the model is trained on the whole training set.\n",
        "# The commented code is kept for reference only and uses the old names `padded` /\n",
        "# `attention_masks` (now `padded_train` / `attention_masks_train`).\n",
        "# Use 70% for training and 30% for validation.\n",
        "#train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, \n",
        "#                                                            random_state=2018, test_size=0.3, stratify = labels)\n",
        "# Do the same for the masks.\n",
        "#train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels,\n",
        "#                                             random_state=2018, test_size=0.3, stratify = labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b9Mw5kq3jhTb"
      },
      "outputs": [],
      "source": [
        "# Wrap ids, attention masks and labels as torch tensors for the TensorDataset.\n",
        "train_inputs, train_masks, train_labels = (\n",
        "    torch.tensor(padded_train),\n",
        "    torch.tensor(attention_masks_train),\n",
        "    torch.tensor(labels_train),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UfFWzbENjnkw"
      },
      "outputs": [],
      "source": [
        "# Training DataLoader: examples are visited in random order, `batch_size` at a time\n",
        "# (the BERT authors recommend 16 or 32 for fine-tuning).\n",
        "train_data = TensorDataset(train_inputs, train_masks, train_labels)\n",
        "train_sampler = RandomSampler(train_data)\n",
        "train_dataloader = DataLoader(\n",
        "    train_data,\n",
        "    sampler=train_sampler,\n",
        "    batch_size=batch_size,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x45JNGqhkUn2"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "d09d664839d04303b8fef9ef895f6e4f",
            "500826e3813b414a820aa260bfde9e23",
            "70dd7428d78c44409308d62ba04917de",
            "152a31110bf9477989833eac91794688",
            "fcde5f4cf49846a0ad1b284aad98a38a",
            "1bf6a76237454349aafc1e9284376879",
            "4a23110523184d019a77368116f738f3",
            "e86a1d4d268c4314897b58f7bba5ec25",
            "826bd7d0a1f146ea9f7d53584468190c",
            "3592b1ed1d6d452b93c57b304943a1cb",
            "a159d62667734657a49ba3a96494f137"
          ]
        },
        "id": "C7M2Er1ajsTf",
        "outputId": "151034cd-9a77-413e-a61e-561c97b4072e"
      },
      "outputs": [],
      "source": [
        "# Pretrained encoder with a single linear classification layer on top, with one\n",
        "# output per encoded class.\n",
        "model_classes = {\n",
        "  \"bert\": BertForSequenceClassification,\n",
        "  \"camembert\": CamembertForSequenceClassification,\n",
        "}\n",
        "model = model_classes[model_chosen].from_pretrained(\n",
        "    model_bert,                   # checkpoint picked from `model_chosen` (cased multilingual BERT or CamemBERT)\n",
        "    num_labels=numberOfClasses,\n",
        "    output_attentions=False,      # don't return attention weights\n",
        "    output_hidden_states=False,   # don't return all hidden states\n",
        ")\n",
        "\n",
        "# Use the device picked by the device-selection cell (cuda, mps or cpu) instead of a\n",
        "# hard-coded \"mps\": the training loop moves every batch to `device`, so both must match.\n",
        "model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xd_cG-8pj4Iw"
      },
      "outputs": [],
      "source": [
        "# AdamW (Adam with decoupled weight decay), the transformers implementation.\n",
        "learning_rate = 2e-5  # run_glue.py's default is 5e-5; 2e-5 is used here\n",
        "adam_epsilon = 1e-8   # run_glue.py's default\n",
        "optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65G-uHuLj4_6"
      },
      "outputs": [],
      "source": [
        "# Number of training epochs (the BERT authors recommend between 2 and 4).\n",
        "epochs = 4\n",
        "\n",
        "# The scheduler is stepped once per batch, so it spans batches * epochs steps.\n",
        "total_steps = epochs * len(train_dataloader)\n",
        "\n",
        "# Linear learning-rate schedule with no warm-up (run_glue.py's default).\n",
        "scheduler = get_linear_schedule_with_warmup(\n",
        "    optimizer,\n",
        "    num_warmup_steps=0,\n",
        "    num_training_steps=total_steps,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SbHBbYpwkKaA",
        "outputId": "4cd1be4a-6014-4804-df56-f38e98039797"
      },
      "outputs": [],
      "source": [
        "# This training code is based on the `run_glue.py` script here:\n",
        "# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128\n",
        "\n",
        "# Seed python, numpy and torch (CPU and CUDA) so the training run is repeatable.\n",
        "# NOTE(review): the classifier was built and the training set resampled in earlier\n",
        "# cells, before these seeds are set, so those steps are not covered by them.\n",
        "seed_val = 42\n",
        "\n",
        "random.seed(seed_val)\n",
        "np.random.seed(seed_val)\n",
        "torch.manual_seed(seed_val)\n",
        "torch.cuda.manual_seed_all(seed_val)\n",
        "\n",
        "# Print a progress line every `progress_every` batches.\n",
        "progress_every = 5\n",
        "\n",
        "# Average training loss of each epoch, kept for plotting the learning curve.\n",
        "loss_values = []\n",
        "\n",
        "for epoch_i in range(0, epochs):\n",
        "\n",
        "    # ======== One full pass over the training set ========\n",
        "    print(\"\")\n",
        "    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))\n",
        "    print('Training...')\n",
        "\n",
        "    t0 = time.time()  # measures how long the epoch takes\n",
        "    total_loss = 0\n",
        "\n",
        "    # Put the model into training mode (dropout active).\n",
        "    model.train()\n",
        "\n",
        "    for step, batch in enumerate(train_dataloader):\n",
        "\n",
        "        if step % progress_every == 0 and not step == 0:\n",
        "            elapsed = format_time(time.time() - t0)\n",
        "            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))\n",
        "\n",
        "        # `batch` holds the TensorDataset columns: [0] input ids, [1] attention\n",
        "        # masks, [2] labels. Copy each one to the training device.\n",
        "        b_input_ids = batch[0].to(device)\n",
        "        b_input_mask = batch[1].to(device)\n",
        "        b_labels = batch[2].to(device)\n",
        "\n",
        "        # PyTorch accumulates gradients, so clear the previous step's first.\n",
        "        model.zero_grad()\n",
        "\n",
        "        # Forward pass; because `labels` are given, outputs[0] is the loss.\n",
        "        outputs = model(b_input_ids,\n",
        "                        token_type_ids=None,\n",
        "                        attention_mask=b_input_mask,\n",
        "                        labels=b_labels)\n",
        "        loss = outputs[0]\n",
        "\n",
        "        # .item() turns the one-element loss tensor into a python float.\n",
        "        total_loss += loss.item()\n",
        "\n",
        "        # Backward pass, then clip the gradient norm to 1.0 against exploding gradients.\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "\n",
        "        # Update the parameters, then advance the learning-rate schedule (once per batch).\n",
        "        optimizer.step()\n",
        "        scheduler.step()\n",
        "\n",
        "    avg_train_loss = total_loss / len(train_dataloader)\n",
        "    loss_values.append(avg_train_loss)\n",
        "\n",
        "    print(\"\")\n",
        "    print(\"  Average training loss: {0:.2f}\".format(avg_train_loss))\n",
        "    print(\"  Training epoch took: {:}\".format(format_time(time.time() - t0)))\n",
        "\n",
        "print(\"\")\n",
        "print(\"Training complete!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uEe7lPtVKpIY"
      },
      "source": [
        "## Saving model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYCSVm_wKnuM"
      },
      "outputs": [],
      "source": [
        "# Name encodes checkpoint and per-class cap, e.g. \"bert-base-multilingual-cased_s10000\".\n",
        "# save_pretrained treats model_path as a directory, despite the \".pt\" suffix.\n",
        "name = \"{}_s{}\".format(model_bert, maxOfInstancePerClass)\n",
        "model_path = \"{}model_{}.pt\".format(path, name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qmsxrOqjCsGo"
      },
      "outputs": [],
      "source": [
        "# Superseded by model.save_pretrained(model_path) in the next cell; kept for reference.\n",
        "#torch.save(model, model_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Save weights and config in Hugging Face format (reloadable with from_pretrained).\n",
        "# ludo: switched to this from torch.save(model, model_path).\n",
        "model.save_pretrained(save_directory=model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pM9bSsckCndR"
      },
      "source": [
        "## Loading model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cEycmiS8Cnjw"
      },
      "outputs": [],
      "source": [
        "#model = torch.load(model_path)\n",
        "# Reload with the class that matches the trained checkpoint (CamemBERT checkpoints are\n",
        "# RoBERTa-style and do not map onto the BERT class) and use the selected `device`\n",
        "# instead of a hard-coded \"mps\".\n",
        "model_class = CamembertForSequenceClassification if model_chosen == \"camembert\" else BertForSequenceClassification\n",
        "model = model_class.from_pretrained(model_path).to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VJwyfmakkQyj"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K9qdtYexIIvk"
      },
      "outputs": [],
      "source": [
        "def evaluate_bert(data, labels, model, batch_size):\n",
        "    \"\"\"Predict a class for every text in `data` and return it alongside the gold labels.\n",
        "\n",
        "    Uses the notebook globals `tokenizer`, `max_len` and `device`.\n",
        "\n",
        "    Args:\n",
        "        data: iterable of texts; each item is converted with str() before encoding.\n",
        "        labels: encoded gold labels, one per text, in the same order as `data`.\n",
        "        model: fine-tuned sequence-classification model (run in eval mode, without gradients).\n",
        "        batch_size: number of texts per forward pass.\n",
        "\n",
        "    Returns:\n",
        "        (pred_labels_, true_labels_): flat lists with the argmax class of each text and\n",
        "        its gold label, in input order.\n",
        "    \"\"\"\n",
        "    # Encode each text with its special tokens (e.g. [CLS]/[SEP] for BERT), then cut it\n",
        "    # to max_len ids or right-pad it with 0 up to max_len.\n",
        "    encoded = [tokenizer.encode(str(text), add_special_tokens=True) for text in data]\n",
        "    fixed_length = [\n",
        "        ids[:max_len] if len(ids) > max_len else ids + [0] * (max_len - len(ids))\n",
        "        for ids in encoded\n",
        "    ]\n",
        "    input_ids = np.array(fixed_length)\n",
        "\n",
        "    # Attention mask: 1.0 for every non-zero id, 0.0 for the padding.\n",
        "    # NOTE(review): assumes no real token has id 0, since the mask is derived from the ids.\n",
        "    attention_masks = [[float(token_id > 0) for token_id in row] for row in input_ids]\n",
        "\n",
        "    inputs_tensor = torch.tensor(input_ids)\n",
        "    eval_data = TensorDataset(inputs_tensor, torch.tensor(attention_masks), torch.tensor(labels))\n",
        "    eval_loader = DataLoader(eval_data, sampler=SequentialSampler(eval_data), batch_size=batch_size)\n",
        "\n",
        "    print('Predicting labels for {:,} test sentences...'.format(len(inputs_tensor)))\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    batch_predictions, batch_labels = [], []\n",
        "    for batch in eval_loader:\n",
        "        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)\n",
        "\n",
        "        # Inference only: skip building the autograd graph.\n",
        "        with torch.no_grad():\n",
        "            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)\n",
        "\n",
        "        # Index of the highest logit for each text of the batch.\n",
        "        logits = outputs[0].detach().cpu().numpy()\n",
        "        batch_predictions.append(np.argmax(logits, axis=1).flatten())\n",
        "        batch_labels.append(b_labels.to('cpu').numpy())\n",
        "\n",
        "    print('    DONE.')\n",
        "    print('Calculating the matrics for each batch...')\n",
        "\n",
        "    pred_labels_ = [label for chunk in batch_predictions for label in chunk]\n",
        "    true_labels_ = [label for chunk in batch_labels for label in chunk]\n",
        "    return pred_labels_, true_labels_"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dPjV_5g8DDQy"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from transformers import AutoModelForSequenceClassification\n",
        "\n",
        "# Evaluate the fine-tuned model on one split and write the reports under `path`.\n",
        "dataset = \"test\"\n",
        "\n",
        "df_eval = pd.read_csv(dataset+\"_set.tsv\", sep=\"\\t\")\n",
        "\n",
        "data_eval = df_eval[columnText].values\n",
        "\n",
        "# Encode the gold class names with the encoder fitted on the training set.\n",
        "y = encoder.transform(df_eval[columnClass])\n",
        "labels = y.tolist()\n",
        "\n",
        "# The model is written with save_pretrained (a directory), so it has to be reloaded with\n",
        "# from_pretrained; torch.load cannot read that format. The Auto class picks BERT or CamemBERT\n",
        "# from the saved config, and the model goes to the `device` that evaluate_bert sends batches to.\n",
        "model_path = path+\"/model_\"+model_bert+\"_s\"+str(maxOfInstancePerClass)+\".pt\"\n",
        "model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)\n",
        "\n",
        "if model_bert == \"bert-base-multilingual-cased\":\n",
        "  tokenizer = BertTokenizer.from_pretrained(model_bert)\n",
        "elif model_bert == \"camembert-base\":\n",
        "  tokenizer = CamembertTokenizer.from_pretrained(model_bert)\n",
        "else:\n",
        "  raise ValueError(\"Unsupported model_bert: \" + model_bert)\n",
        "\n",
        "pred_labels_, true_labels_ = evaluate_bert(data_eval, labels, model, batch_size)\n",
        "\n",
        "# Score against the full label set: every class then gets a row in the report and the\n",
        "# confusion matrix stays k x k even if a class is absent from both gold and predicted labels\n",
        "# (otherwise report[c] raises KeyError and FP/FN/TP/TN no longer line up with classesName).\n",
        "class_ids = encoder.transform(encoder.classes_)\n",
        "classes = [str(e) for e in class_ids]\n",
        "classesName = encoder.classes_\n",
        "\n",
        "report = classification_report(true_labels_, pred_labels_, labels=class_ids, output_dict=True)\n",
        "\n",
        "precision = [report[c]['precision'] for c in classes]\n",
        "recall = [report[c]['recall'] for c in classes]\n",
        "f1 = [report[c]['f1-score'] for c in classes]\n",
        "support = [report[c]['support'] for c in classes]\n",
        "\n",
        "accuracy = report['accuracy']\n",
        "weighted_avg = report['weighted avg']\n",
        "cnf_matrix = confusion_matrix(true_labels_, pred_labels_, labels=class_ids)\n",
        "TP = np.diag(cnf_matrix)\n",
        "FP = cnf_matrix.sum(axis=0) - TP\n",
        "FN = cnf_matrix.sum(axis=1) - TP\n",
        "TN = cnf_matrix.sum() - (FP + FN + TP)\n",
        "\n",
        "dff = pd.DataFrame({\n",
        "  'className': classesName,\n",
        "  'precision': precision,\n",
        "  'recall': recall,\n",
        "  'f1-score': f1,\n",
        "  'support': support,\n",
        "  'FP': FP,\n",
        "  'FN': FN,\n",
        "  'TP': TP,\n",
        "  'TN': TN,\n",
        "})\n",
        "\n",
        "print(name)\n",
        "\n",
        "# Keep the prefixed name in its own variable: reassigning `name` made every re-run of this\n",
        "# cell add one more \"test_\" prefix. The prefix now follows `dataset`.\n",
        "report_name = dataset + \"_\" + name\n",
        "content = report_name + \"\\n\"\n",
        "print(report_name)\n",
        "content += str(weighted_avg) + \"\\n\"\n",
        "\n",
        "print(weighted_avg)\n",
        "print(accuracy)\n",
        "print(dff)\n",
        "\n",
        "# to_csv / open do not create missing folders.\n",
        "os.makedirs(path+\"/predictions\", exist_ok=True)\n",
        "os.makedirs(path+\"/reports\", exist_ok=True)\n",
        "\n",
        "dff.to_csv(path+\"/report_\"+report_name+\".csv\", index=False)\n",
        "# Save the predictions\n",
        "pd.DataFrame({'labels': pd.Series(true_labels_), 'predictions': pd.Series(pred_labels_)}).to_csv(path+\"/predictions/predictions_\"+report_name+\".csv\")\n",
        "\n",
        "# Was path+\"reports/...\" (no \"/\"), which wrote next to `path` instead of into its reports folder.\n",
        "with open(path+\"/reports/report_\"+report_name+\".txt\", 'w') as f:\n",
        "  f.write(content)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cVdM4eT6I8g2"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HzxyFO3knanV"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KDRPPw4Wnap7"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DX81R2dcnasF"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wgfqJFVeJMK1"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GqEf5_41JMNZ"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_n57EvhJMQh"
      },
      "outputs": [],
      "source": [
        "# Path of a previously fine-tuned multilingual BERT model, read by the torch.load cell below\n",
        "# (presumably a single file written with torch.save; models written by save_pretrained are\n",
        "# directories and need from_pretrained instead).\n",
        "model_path = \"drive/MyDrive/Classification-EDdA/model_bert-base-multilingual-cased_s10000.pt\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R3_9tA9MI8ju"
      },
      "outputs": [],
      "source": [
        "# torch.load unpickles the whole model object, so only load trusted files.\n",
        "# map_location puts the tensors on the current `device`, so a checkpoint saved on a GPU\n",
        "# can also be loaded on a CPU-only runtime.\n",
        "model = torch.load(model_path, map_location=device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_fzgS5USJeAF",
        "outputId": "be4a5506-76ed-4eef-bb3c-fe2bb77c6e4d"
      },
      "outputs": [],
      "source": [
        "# -nc (no-clobber): skip the download when the file already exists, instead of writing\n",
        "# LGE_withContent.tsv.1, .2, ... on every re-run of the notebook.\n",
        "!wget -nc https://projet.liris.cnrs.fr/geode/files/datasets/EDdA/Classification/LGE_withContent.tsv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8WEJjQC7I8mP"
      },
      "outputs": [],
      "source": [
        "# Unlabelled LGE articles to classify: only their text column is fed to the model.\n",
        "df_LGE = pd.read_csv(\"LGE_withContent.tsv\", sep=\"\\t\")\n",
        "data_LGE = df_LGE[\"content\"].to_numpy()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "9qJDTU-6vzkk",
        "outputId": "1b279f0e-7715-4d23-f524-08e8ba327f6c"
      },
      "outputs": [],
      "source": [
        "# Preview the first rows of the LGE articles.\n",
        "df_LGE.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "71-fP61-OOwQ",
        "outputId": "ef08b49e-0a9f-4653-e303-3163250af35b"
      },
      "outputs": [],
      "source": [
        "# (number of articles, number of columns)\n",
        "df_LGE.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFFed2EAI8oq"
      },
      "outputs": [],
      "source": [
        "def generate_prediction_dataloader(chosen_model, sentences_to_predict, batch_size = 8, max_len = 512):\n",
        "    \"\"\"Tokenize texts and wrap them in a sequential DataLoader for prediction.\n",
        "\n",
        "    Args:\n",
        "        chosen_model: 'bert-base-multilingual-cased' or 'camembert-base'; selects the tokenizer.\n",
        "        sentences_to_predict: iterable of texts; each item goes through str() first, so a\n",
        "            missing value (NaN) is encoded as text instead of crashing the tokenizer.\n",
        "        batch_size: number of texts per batch.\n",
        "        max_len: every id sequence is cut or right-padded with 0 to exactly this length.\n",
        "\n",
        "    Returns:\n",
        "        A DataLoader yielding (input_ids, attention_mask) batches in input order.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if chosen_model is not one of the two supported names.\n",
        "    \"\"\"\n",
        "    if chosen_model == 'bert-base-multilingual-cased':\n",
        "        print('Loading Bert Tokenizer...')\n",
        "        tokenizer = BertTokenizer.from_pretrained(chosen_model)\n",
        "    elif chosen_model == 'camembert-base':\n",
        "        print('Loading Camembert Tokenizer...')\n",
        "        tokenizer = CamembertTokenizer.from_pretrained(chosen_model)\n",
        "    else:\n",
        "        # Used to fall through and fail later with an UnboundLocalError on `tokenizer`.\n",
        "        raise ValueError('Unsupported model: {}'.format(chosen_model))\n",
        "\n",
        "    # Encode each text with its special tokens (e.g. [CLS]/[SEP] for BERT), then cut it to\n",
        "    # max_len ids or right-pad it with 0. str() matches evaluate_bert's preprocessing.\n",
        "    input_ids_test = []\n",
        "    for sent in sentences_to_predict:\n",
        "        encoded_sent = tokenizer.encode(str(sent), add_special_tokens=True)\n",
        "        if len(encoded_sent) > max_len:\n",
        "            input_ids_test.append(encoded_sent[:max_len])\n",
        "        else:\n",
        "            input_ids_test.append(encoded_sent + [0] * (max_len - len(encoded_sent)))\n",
        "    input_ids_test = np.array(input_ids_test)\n",
        "\n",
        "    # Attention mask: 1.0 for every non-zero id, 0.0 for the padding.\n",
        "    # NOTE(review): assumes no real token has id 0, since the mask is derived from the ids.\n",
        "    attention_masks = [[float(i > 0) for i in seq] for seq in input_ids_test]\n",
        "\n",
        "    prediction_inputs = torch.tensor(input_ids_test)\n",
        "    prediction_masks = torch.tensor(attention_masks)\n",
        "\n",
        "    # Sequential sampling keeps the batches in the same order as the input texts.\n",
        "    prediction_data = TensorDataset(prediction_inputs, prediction_masks)\n",
        "    prediction_sampler = SequentialSampler(prediction_data)\n",
        "    return DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)\n",
        "\n",
        "\n",
        "\n",
        "def predict_class_bertFineTuning(model, sentences_to_predict_dataloader):\n",
        "    \"\"\"Predict the class index of every text in a prediction DataLoader.\n",
        "\n",
        "    Args:\n",
        "        model: fine-tuned sequence-classification model; it is moved to the selected\n",
        "            device (CUDA if available, otherwise CPU) and put in eval mode.\n",
        "        sentences_to_predict_dataloader: iterable of (input_ids, attention_mask) batches,\n",
        "            e.g. the DataLoader built by generate_prediction_dataloader.\n",
        "\n",
        "    Returns:\n",
        "        Flat list with the argmax class index of each text, in batch order\n",
        "        (an empty list when there are no batches).\n",
        "    \"\"\"\n",
        "    if torch.cuda.is_available():\n",
        "        device = torch.device(\"cuda\")\n",
        "        print('There are %d GPU(s) available.' % torch.cuda.device_count())\n",
        "        print('We will use the GPU:', torch.cuda.get_device_name(0))\n",
        "    else:\n",
        "        print('No GPU available, using the CPU instead.')\n",
        "        device = torch.device(\"cpu\")\n",
        "\n",
        "    # The batches are sent to `device`, so the model has to be there as well (it may have been\n",
        "    # loaded onto another device, e.g. \"mps\" or the GPU it was saved from).\n",
        "    model.to(device)\n",
        "    model.eval()\n",
        "\n",
        "    pred_labels_ = []\n",
        "    for batch in sentences_to_predict_dataloader:\n",
        "        b_input_ids, b_input_mask = (t.to(device) for t in batch)\n",
        "\n",
        "        # Inference only: skip building the autograd graph.\n",
        "        with torch.no_grad():\n",
        "            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)\n",
        "\n",
        "        # Argmax of this batch only. The previous version re-ran argmax over all earlier\n",
        "        # batches on every iteration (quadratic) and raised NameError on an empty loader.\n",
        "        logits = outputs[0].detach().cpu().numpy()\n",
        "        pred_labels_.extend(np.argmax(logits, axis=1).flatten())\n",
        "\n",
        "    return pred_labels_\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "O9eer_kgI8rC",
        "outputId": "94ea7418-14a8-4918-e210-caf0018f5989"
      },
      "outputs": [],
      "source": [
        "# The tokenizer must match the loaded model (the multilingual BERT checkpoint at model_path);\n",
        "# switch to the commented camembert line when predicting with a CamemBERT model.\n",
        "data_loader = generate_prediction_dataloader('bert-base-multilingual-cased', data_LGE)\n",
        "#data_loader = generate_prediction_dataloader('camembert-base', data_LGE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sFpAwbrBwF2h",
        "outputId": "8d210732-619d-41f0-b6e2-ad9d06a85069"
      },
      "outputs": [],
      "source": [
        "# Predicted class index for every LGE article, in the order of data_LGE.\n",
        "p = predict_class_bertFineTuning(model=model, sentences_to_predict_dataloader=data_loader)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "51HF6-8UPSTc",
        "outputId": "26bff792-eb8d-4e1a-efa4-a7a6c9d32bf9"
      },
      "outputs": [],
      "source": [
        "# Sanity check: should equal df_LGE.shape[0] (one prediction per article).\n",
        "len(p)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rFFGhaCvQHfh"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qgJ-O4rcQHiI",
        "outputId": "bfe93dd6-4d89-4d5c-be0d-45e1c98c6b14"
      },
      "outputs": [],
      "source": [
        "# The encoder should be saved along with the model;\n",
        "# otherwise it has to be rebuilt from the training set to recover the class names.\n",
        "encoder"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QuST9wJoQHnS"
      },
      "outputs": [],
      "source": [
        "# Map the predicted class indices back to class names.\n",
        "p2 = [class_name for class_name in encoder.inverse_transform(p)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6ek7suq9QHqE",
        "outputId": "6636983a-7eba-48c8-d884-f8fb437294dc"
      },
      "outputs": [],
      "source": [
        "# Display the predicted class names.\n",
        "p2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XvdDj5PBQHtk"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t39Xs0j7QHXJ"
      },
      "outputs": [],
      "source": [
        "# Attach the predicted class name to each article; p2 follows the row order of df_LGE.\n",
        "df_LGE['class_bert'] = p2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "-VZ7geRmQHaD",
        "outputId": "350a4122-5b1f-43e2-e372-2f628f665c4a"
      },
      "outputs": [],
      "source": [
        "# Preview the articles with their predicted class.\n",
        "df_LGE.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3xkzdkrKQHwA"
      },
      "outputs": [],
      "source": [
        "# Tab-separated output; to_csv also writes the DataFrame index as a leading column (default index=True).\n",
        "df_LGE.to_csv(\"drive/MyDrive/Classification-EDdA/classification_LGE.tsv\", sep=\"\\t\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "EDdA-Classification_BertFineTuning.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.9.13 ('geode-classification-py39')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]"
    },
    "vscode": {
      "interpreter": {
        "hash": "16fac9c2d845f8e1f8c6fffffe3d3a0def61c7e42da17a08d00f279ad4dea797"
      }
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0279837673b446b09aac18346213eb7e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_69004a5069094f8c9d59d5136f627bef",
              "IPY_MODEL_e96a95317b0945c58c8ff0e944c7593e",
              "IPY_MODEL_68b69c9d3a274900bc2892848f725cb0"
            ],
            "layout": "IPY_MODEL_09b5f0bbd5c14bc289b0f92a22bb29ab"
          }
        },
        "0779c8ea0ed24e64a800ae5dff6bc8ce": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8d24b669a39b4876ac0a014dff678db1",
            "max": 810912,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_99b785ea53744868b8b11e5e94936fcc",
            "value": 810912
          }
        },
        "09b5f0bbd5c14bc289b0f92a22bb29ab": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "12afa6b6474b401f9ff3f189cc0d3d11": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "152a31110bf9477989833eac91794688": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_826bd7d0a1f146ea9f7d53584468190c",
            "max": 445032417,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_e86a1d4d268c4314897b58f7bba5ec25",
            "value": 445032417
          }
        },
        "1bf6a76237454349aafc1e9284376879": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "274e505b5f354efc8de3ef26cc43e617": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ad5e0e1439a94578a31b80c90dbf3247",
              "IPY_MODEL_0779c8ea0ed24e64a800ae5dff6bc8ce",
              "IPY_MODEL_7870340ac12b469c8ac19de3a47b6e67"
            ],
            "layout": "IPY_MODEL_f1f9d5b32f60473b86ae6b340d6c0850"
          }
        },
        "2c44d9c11e704b70aa32904a23d1790c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2cf386a8d14d43389374f79bfa922675": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2d1d632da0f740c38512c9ad779d3173": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3592b1ed1d6d452b93c57b304943a1cb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4a23110523184d019a77368116f738f3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "500826e3813b414a820aa260bfde9e23": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5032547e748f45a3b0cdd12fafe1dd05": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "58b4f9e0366f4d4eba7f902af84b8965": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_12afa6b6474b401f9ff3f189cc0d3d11",
            "placeholder": "​",
            "style": "IPY_MODEL_c4d981755d1d42b6940396b77bc251bc",
            "value": "Downloading: 100%"
          }
        },
        "5978954f56fb40928b970f32d1634aaf": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5f321455342348f49879a9ca8b392077": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "68b69c9d3a274900bc2892848f725cb0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5032547e748f45a3b0cdd12fafe1dd05",
            "placeholder": "​",
            "style": "IPY_MODEL_c4c1675163bd4997bb44d7ea3967a356",
            "value": " 1.40M/1.40M [00:00&lt;00:00, 6.57MB/s]"
          }
        },
        "69004a5069094f8c9d59d5136f627bef": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cb447c62ce1d4c1ea760175ae619fbb9",
            "placeholder": "​",
            "style": "IPY_MODEL_76007b17ffd2478fa4a86f959d4f1766",
            "value": "Downloading: 100%"
          }
        },
        "70dd7428d78c44409308d62ba04917de": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4a23110523184d019a77368116f738f3",
            "placeholder": "​",
            "style": "IPY_MODEL_1bf6a76237454349aafc1e9284376879",
            "value": "Downloading: 100%"
          }
        },
        "76007b17ffd2478fa4a86f959d4f1766": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7870340ac12b469c8ac19de3a47b6e67": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2c44d9c11e704b70aa32904a23d1790c",
            "placeholder": "​",
            "style": "IPY_MODEL_2cf386a8d14d43389374f79bfa922675",
            "value": " 811k/811k [00:00&lt;00:00, 2.75MB/s]"
          }
        },
        "826bd7d0a1f146ea9f7d53584468190c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8d24b669a39b4876ac0a014dff678db1": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8f467553598f4dcc9abf55da79c11018": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_58b4f9e0366f4d4eba7f902af84b8965",
              "IPY_MODEL_9383e09698ae4bd1820a4bca22e78315",
              "IPY_MODEL_a189838c4de648198b0f4fc99c29ced8"
            ],
            "layout": "IPY_MODEL_9d7a8b3ecfe74f66b4238fe085c05906"
          }
        },
        "9383e09698ae4bd1820a4bca22e78315": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fe0e3b1df104484c98fbdcd31a04e427",
            "max": 508,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5978954f56fb40928b970f32d1634aaf",
            "value": 508
          }
        },
        "9420a47a2bf44ead8cff283f20566cda": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "99b785ea53744868b8b11e5e94936fcc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "9d7a8b3ecfe74f66b4238fe085c05906": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a159d62667734657a49ba3a96494f137": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a189838c4de648198b0f4fc99c29ced8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_df95c20399dd4918bc7559a90886d4aa",
            "placeholder": "​",
            "style": "IPY_MODEL_2d1d632da0f740c38512c9ad779d3173",
            "value": " 508/508 [00:00&lt;00:00, 16.9kB/s]"
          }
        },
        "a9c47cb226ee41e18812f29f690992eb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ad5e0e1439a94578a31b80c90dbf3247": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9420a47a2bf44ead8cff283f20566cda",
            "placeholder": "​",
            "style": "IPY_MODEL_5f321455342348f49879a9ca8b392077",
            "value": "Downloading: 100%"
          }
        },
        "c4c1675163bd4997bb44d7ea3967a356": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c4d981755d1d42b6940396b77bc251bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cb447c62ce1d4c1ea760175ae619fbb9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d09d664839d04303b8fef9ef895f6e4f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_70dd7428d78c44409308d62ba04917de",
              "IPY_MODEL_152a31110bf9477989833eac91794688",
              "IPY_MODEL_fcde5f4cf49846a0ad1b284aad98a38a"
            ],
            "layout": "IPY_MODEL_500826e3813b414a820aa260bfde9e23"
          }
        },
        "d4ad1a78750d49feaea584a82940bb7d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "df95c20399dd4918bc7559a90886d4aa": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e86a1d4d268c4314897b58f7bba5ec25": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e96a95317b0945c58c8ff0e944c7593e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a9c47cb226ee41e18812f29f690992eb",
            "max": 1395301,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d4ad1a78750d49feaea584a82940bb7d",
            "value": 1395301
          }
        },
        "f1f9d5b32f60473b86ae6b340d6c0850": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fcde5f4cf49846a0ad1b284aad98a38a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a159d62667734657a49ba3a96494f137",
            "placeholder": "​",
            "style": "IPY_MODEL_3592b1ed1d6d452b93c57b304943a1cb",
            "value": " 445M/445M [00:14&lt;00:00, 39.2MB/s]"
          }
        },
        "fe0e3b1df104484c98fbdcd31a04e427": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}