From 48f4b22e4c0206a01a8157273e53f87bf5a514f8 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Thu, 17 Nov 2022 22:29:16 +0100
Subject: [PATCH] Predict.ipynb: fix fine-tuned model loading and refresh outputs

---
 notebooks/Predict.ipynb | 534 ++++++++++++++++------------------------
 1 file changed, 217 insertions(+), 317 deletions(-)

diff --git a/notebooks/Predict.ipynb b/notebooks/Predict.ipynb
index ab25086..f479b3e 100644
--- a/notebooks/Predict.ipynb
+++ b/notebooks/Predict.ipynb
@@ -92,7 +92,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 1,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -100,7 +100,15 @@
         "id": "dPOU-Efhf4ui",
         "outputId": "121dd21e-f98c-483d-d6d1-2838f732a4e2"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "We will use the GPU\n"
+          ]
+        }
+      ],
       "source": [
         "import torch\n",
         "\n",
@@ -131,7 +139,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 23,
       "metadata": {
         "id": "SkErnwgMMbRj"
       },
@@ -140,7 +148,7 @@
         "import pandas as pd \n",
         "import numpy as np\n",
         "\n",
-        "from transformers import BertTokenizer, CamembertTokenizer\n",
+        "from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, CamembertTokenizer, CamembertForSequenceClassification\n",
         "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler"
       ]
     },
@@ -155,7 +163,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 3,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -174,7 +182,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 4,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -184,18 +192,129 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 5,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<div>\n",
+              "<style scoped>\n",
+              "    .dataframe tbody tr th:only-of-type {\n",
+              "        vertical-align: middle;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe tbody tr th {\n",
+              "        vertical-align: top;\n",
+              "    }\n",
+              "\n",
+              "    .dataframe thead th {\n",
+              "        text-align: right;\n",
+              "    }\n",
+              "</style>\n",
+              "<table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              "    <tr style=\"text-align: right;\">\n",
+              "      <th></th>\n",
+              "      <th>id</th>\n",
+              "      <th>tome</th>\n",
+              "      <th>rank</th>\n",
+              "      <th>domain</th>\n",
+              "      <th>remark</th>\n",
+              "      <th>content</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <th>0</th>\n",
+              "      <td>abrabeses-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>623</td>\n",
+              "      <td>geography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ABRABESES. Village d’Espagne de la prov. de Za...</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>1</th>\n",
+              "      <td>accius-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1076</td>\n",
+              "      <td>biography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>2</th>\n",
+              "      <td>achenbach-2</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1357</td>\n",
+              "      <td>biography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACHENBACH(Henri), administrateur prussien, né ...</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>3</th>\n",
+              "      <td>acireale-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1513</td>\n",
+              "      <td>geography</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACIREALE. Yille de Sicile, de la province et d...</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <th>4</th>\n",
+              "      <td>actée-0</td>\n",
+              "      <td>1</td>\n",
+              "      <td>1731</td>\n",
+              "      <td>botany</td>\n",
+              "      <td>NaN</td>\n",
+              "      <td>ACTÉE(Actœa L.). Genre de plantes de la famill...</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table>\n",
+              "</div>"
+            ],
+            "text/plain": [
+              "            id  tome  rank     domain remark  \\\n",
+              "0  abrabeses-0     1   623  geography    NaN   \n",
+              "1     accius-0     1  1076  biography    NaN   \n",
+              "2  achenbach-2     1  1357  biography    NaN   \n",
+              "3   acireale-0     1  1513  geography    NaN   \n",
+              "4      actée-0     1  1731     botany    NaN   \n",
+              "\n",
+              "                                             content  \n",
+              "0  ABRABESES. Village d’Espagne de la prov. de Za...  \n",
+              "1  ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...  \n",
+              "2  ACHENBACH(Henri), administrateur prussien, né ...  \n",
+              "3  ACIREALE. Yille de Sicile, de la province et d...  \n",
+              "4  ACTÉE(Actœa L.). Genre de plantes de la famill...  "
+            ]
+          },
+          "execution_count": 5,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
       "source": [
         "df_LGE.head()"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 6,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "(310, 6)"
+            ]
+          },
+          "execution_count": 6,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
       "source": [
         "df_LGE.shape"
       ]
@@ -211,7 +330,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 7,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -222,7 +341,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 8,
       "metadata": {},
       "outputs": [],
       "source": [
@@ -240,6 +359,11 @@
         "        encoded_sent = tokenizer.encode(\n",
         "                            sent,                      # Sentence to encode.\n",
         "                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'\n",
+        "                            # This function also supports truncation and conversion\n",
+        "                            # to pytorch tensors, but I need to do padding, so I\n",
+        "                            # can't use these features.\n",
+        "                            #max_length = max_len,          # Truncate all sentences.\n",
+        "                            #return_tensors = 'pt',     # Return pytorch tensors.\n",
         "                    )\n",
         "        input_ids_test.append(encoded_sent)\n",
         "\n",
@@ -261,15 +385,15 @@
         "        attention_masks.append(seq_mask)\n",
         "\n",
         "    # Convert to tensors.\n",
-        "    prediction_inputs = torch.tensor(input_ids_test)\n",
-        "    prediction_masks = torch.tensor(attention_masks)\n",
+        "    inputs = torch.tensor(input_ids_test)\n",
+        "    masks = torch.tensor(attention_masks)\n",
         "    #set batch size\n",
         "\n",
         "    # Create the DataLoader.\n",
-        "    prediction_data = TensorDataset(prediction_inputs, prediction_masks)\n",
-        "    prediction_sampler = SequentialSampler(prediction_data)\n",
+        "    data = TensorDataset(inputs, masks)\n",
+        "    prediction_sampler = SequentialSampler(data)\n",
         "\n",
-        "    return DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)\n",
+        "    return DataLoader(data, sampler=prediction_sampler, batch_size=batch_size)\n",
         "\n",
         "\n",
         "\n",
@@ -321,18 +445,17 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 12,
       "metadata": {},
-      "outputs": [],
-      "source": [
-        "model = torch.load(model_path, map_location=torch.device('mps'))"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Loading Bert Tokenizer...\n"
+          ]
+        }
+      ],
       "source": [
         "if model_name == 'bert-base-multilingual-cased' :\n",
         "    print('Loading Bert Tokenizer...')\n",
@@ -344,348 +467,125 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 13,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "Token indices sequence length is longer than the specified maximum sequence length for this model (1204 > 512). Running this sequence through the model will result in indexing errors\n"
+          ]
+        }
+      ],
       "source": [
         "data_loader = generate_dataloader(tokenizer, data_LGE)"
       ]
     },
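+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "The warning above is expected: `tokenizer.encode` is called without `max_length`, and `generate_dataloader` truncates each sequence to `max_len` in its padding step, so the over-long article does not cause indexing errors at prediction time.\n"
+      ]
+    },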
     {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "_fzgS5USJeAF",
-        "outputId": "be4a5506-76ed-4eef-bb3c-fe2bb77c6e4d"
-      },
-      "outputs": [],
-      "source": []
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "8WEJjQC7I8mP"
-      },
-      "outputs": [],
+      "cell_type": "markdown",
+      "metadata": {},
       "source": [
-        "df_LGE = pd.read_csv(\"LGE_withContent.tsv\", sep=\"\\t\")\n",
-        "data_LGE = df_LGE[\"content\"].values\n",
         "\n",
+        "https://discuss.huggingface.co/t/an-efficient-way-of-loading-a-model-that-was-saved-with-torch-save/9814\n",
         "\n",
-        "#pred_labels_, true_labels_ = evaluate_bert(data_eval, labels, model, batch_size)\n"
+        "https://github.com/huggingface/transformers/issues/2094\n"
       ]
     },
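+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "A minimal sketch of the two loading patterns from the links above, assuming `model_path` holds either a `save_pretrained(...)` directory or a bare `torch.save(model.state_dict(), ...)` weights file (and that the base model's `num_labels` matches the trained head):\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import os\n",
+        "\n",
+        "if os.path.isdir(model_path):\n",
+        "    # save_pretrained() checkpoints are directories (config.json + weights).\n",
+        "    model = BertForSequenceClassification.from_pretrained(model_path)\n",
+        "else:\n",
+        "    # A bare state_dict needs the architecture built first, then the\n",
+        "    # weights are loaded into it.\n",
+        "    model = BertForSequenceClassification.from_pretrained(model_name)\n",
+        "    model.load_state_dict(torch.load(model_path, map_location='cpu'))"
+      ]
+    },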
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 206
-        },
-        "id": "9qJDTU-6vzkk",
-        "outputId": "1b279f0e-7715-4d23-f524-08e8ba327f6c"
-      },
-      "outputs": [],
-      "source": [
-        "df_LGE.head()"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "71-fP61-OOwQ",
-        "outputId": "ef08b49e-0a9f-4653-e303-3163250af35b"
-      },
-      "outputs": [],
-      "source": [
-        "df_LGE.shape"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "lFFed2EAI8oq"
-      },
-      "outputs": [],
+      "execution_count": 26,
+      "metadata": {},
+      "outputs": [
+        {
+          "ename": "TypeError",
+          "evalue": "Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>.",
+          "output_type": "error",
+          "traceback": [
+            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
+            "Cell \u001b[0;32mIn [26], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m#model = torch.load(model_path, map_location=torch.device('mps'))\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model\u001b[39m.\u001b[39;49mload_state_dict(torch\u001b[39m.\u001b[39;49mload(model_path, map_location\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mdevice(\u001b[39m'\u001b[39;49m\u001b[39mmps\u001b[39;49m\u001b[39m'\u001b[39;49m)))\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1620\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m   1597\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Copies parameters and buffers from :attr:`state_dict` into\u001b[39;00m\n\u001b[1;32m   1598\u001b[0m \u001b[39mthis module and its descendants. If :attr:`strict` is ``True``, then\u001b[39;00m\n\u001b[1;32m   1599\u001b[0m \u001b[39mthe keys of :attr:`state_dict` must exactly match the keys returned\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1617\u001b[0m \u001b[39m    ``RuntimeError``.\u001b[39;00m\n\u001b[1;32m   1618\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1619\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(state_dict, Mapping):\n\u001b[0;32m-> 1620\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mExpected state_dict to be dict-like, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mtype\u001b[39m(state_dict)))\n\u001b[1;32m   1622\u001b[0m missing_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n\u001b[1;32m   1623\u001b[0m unexpected_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n",
+            "\u001b[0;31mTypeError\u001b[0m: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>."
+          ]
+        }
+      ],
       "source": [
-        "def generate_prediction_dataloader(chosen_model, sentences_to_predict, batch_size = 8, max_len = 512):\n",
-        "\n",
-        "    if chosen_model == 'bert-base-multilingual-cased' :\n",
-        "        print('Loading Bert Tokenizer...')\n",
-        "        tokenizer = BertTokenizer.from_pretrained(chosen_model)\n",
-        "    elif chosen_model == 'camembert-base':\n",
-        "        print('Loading Camembert Tokenizer...')\n",
-        "        tokenizer = CamembertTokenizer.from_pretrained(chosen_model)\n",
-        "\n",
-        "    # Tokenize all of the sentences and map the tokens to thier word IDs.\n",
-        "    input_ids_test = []\n",
-        "    # For every sentence...\n",
-        "    for sent in sentences_to_predict:\n",
-        "        # `encode` will:\n",
-        "        #   (1) Tokenize the sentence.\n",
-        "        #   (2) Prepend the `[CLS]` token to the start.\n",
-        "        #   (3) Append the `[SEP]` token to the end.\n",
-        "        #   (4) Map tokens to their IDs.\n",
-        "        encoded_sent = tokenizer.encode(\n",
-        "                            sent,                      # Sentence to encode.\n",
-        "                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'\n",
-        "                    )\n",
-        "\n",
-        "        input_ids_test.append(encoded_sent)\n",
-        "\n",
-        "    # Pad our input tokens\n",
-        "    padded_test = []\n",
-        "    for i in input_ids_test:\n",
-        "\n",
-        "        if len(i) > max_len:\n",
-        "            padded_test.extend([i[:max_len]])\n",
-        "        else:\n",
-        "\n",
-        "            padded_test.extend([i + [0] * (max_len - len(i))])\n",
-        "    input_ids_test = np.array(padded_test)\n",
-        "\n",
-        "    # Create attention masks\n",
-        "    attention_masks = []\n",
-        "\n",
-        "    # Create a mask of 1s for each token followed by 0s for padding\n",
-        "    for seq in input_ids_test:\n",
-        "        seq_mask = [float(i>0) for i in seq]\n",
-        "        attention_masks.append(seq_mask)\n",
-        "\n",
-        "    # Convert to tensors.\n",
-        "    prediction_inputs = torch.tensor(input_ids_test)\n",
-        "    prediction_masks = torch.tensor(attention_masks)\n",
-        "    #set batch size\n",
-        "\n",
-        "\n",
-        "    # Create the DataLoader.\n",
-        "    prediction_data = TensorDataset(prediction_inputs, prediction_masks)\n",
-        "    prediction_sampler = SequentialSampler(prediction_data)\n",
-        "    prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)\n",
-        "\n",
-        "    return prediction_dataloader\n",
-        "\n",
-        "\n",
-        "\n",
-        "def predict_class_bertFineTuning(model, sentences_to_predict_dataloader):\n",
-        "\n",
-        "\n",
-        "    # If there's a GPU available...\n",
-        "    if torch.cuda.is_available():\n",
-        "\n",
-        "        # Tell PyTorch to use the GPU.\n",
-        "        device = torch.device(\"cuda\")\n",
-        "\n",
-        "        print('There are %d GPU(s) available.' % torch.cuda.device_count())\n",
-        "\n",
-        "        print('We will use the GPU:', torch.cuda.get_device_name(0))\n",
-        "\n",
-        "        # If not...\n",
-        "    else:\n",
-        "        print('No GPU available, using the CPU instead.')\n",
-        "        device = torch.device(\"cpu\")\n",
+        "#model = torch.load(model_path, map_location=torch.device('mps'))\n",
+        "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
         "\n",
-        "    # Put model in evaluation mode\n",
-        "    model.eval()\n",
-        "\n",
-        "    # Tracking variables\n",
-        "    predictions_test , true_labels = [], []\n",
-        "    pred_labels_ = []\n",
-        "    # Predict\n",
-        "    for batch in sentences_to_predict_dataloader:\n",
-        "    # Add batch to GPU\n",
-        "        batch = tuple(t.to(device) for t in batch)\n",
-        "\n",
-        "        # Unpack the inputs from the dataloader\n",
-        "        b_input_ids, b_input_mask = batch\n",
-        "\n",
-        "        # Telling the model not to compute or store gradients, saving memory and\n",
-        "        # speeding up prediction\n",
-        "        with torch.no_grad():\n",
-        "            # Forward pass, calculate logit predictions\n",
-        "            outputs = model(b_input_ids, token_type_ids=None,\n",
-        "                            attention_mask=b_input_mask)\n",
-        "\n",
-        "        logits = outputs[0]\n",
-        "        #print(logits)\n",
-        "\n",
-        "        # Move logits and labels to CPU\n",
-        "        logits = logits.detach().cpu().numpy()\n",
-        "        #print(logits)\n",
-        "\n",
-        "        # Store predictions and true labels\n",
-        "        predictions_test.append(logits)\n",
-        "\n",
-        "        #print('    DONE.')\n",
-        "\n",
-        "        pred_labels = []\n",
-        "        \n",
-        "        for i in range(len(predictions_test)):\n",
-        "\n",
-        "            # The predictions for this batch are a 2-column ndarray (one column for \"0\"\n",
-        "            # and one column for \"1\"). Pick the label with the highest value and turn this\n",
-        "            # in to a list of 0s and 1s.\n",
-        "            pred_labels_i = np.argmax(predictions_test[i], axis=1).flatten()\n",
-        "            pred_labels.append(pred_labels_i)\n",
-        "\n",
-        "    pred_labels_ += [item for sublist in pred_labels for item in sublist]\n",
-        "    return pred_labels_\n"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "O9eer_kgI8rC",
-        "outputId": "94ea7418-14a8-4918-e210-caf0018f5989"
-      },
-      "outputs": [],
-      "source": [
-        "data_loader = generate_prediction_dataloader('bert-base-multilingual-cased', data_LGE)\n",
-        "#data_loader = generate_prediction_dataloader('camembert-base', data_LGE)"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "sFpAwbrBwF2h",
-        "outputId": "8d210732-619d-41f0-b6e2-ad9d06a85069"
-      },
-      "outputs": [],
-      "source": [
-        "p = predict_class_bertFineTuning( model, data_loader )"
+        "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\") #.to(\"cuda\")"
       ]
     },
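+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "Note: `from_pretrained` expects `model_path` to be a directory written by `save_pretrained` (or a hub model id), and the `.to(\"mps\")` move must match the `device` that is later passed to `predict`.\n"
+      ]
+    },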
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 14,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "id": "51HF6-8UPSTc",
-        "outputId": "26bff792-eb8d-4e1a-efa4-a7a6c9d32bf9"
+        "id": "_fzgS5USJeAF",
+        "outputId": "be4a5506-76ed-4eef-bb3c-fe2bb77c6e4d"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "ename": "AttributeError",
+          "evalue": "'BertEncoder' object has no attribute 'gradient_checkpointing'",
+          "output_type": "error",
+          "traceback": [
+            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+            "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+            "Cell \u001b[0;32mIn [14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pred \u001b[39m=\u001b[39m predict(model, data_loader, device)\n",
+            "Cell \u001b[0;32mIn [8], line 68\u001b[0m, in \u001b[0;36mpredict\u001b[0;34m(model, dataloader, device)\u001b[0m\n\u001b[1;32m     64\u001b[0m \u001b[39m# Telling the model not to compute or store gradients, saving memory and\u001b[39;00m\n\u001b[1;32m     65\u001b[0m \u001b[39m# speeding up prediction\u001b[39;00m\n\u001b[1;32m     66\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m     67\u001b[0m     \u001b[39m# Forward pass, calculate logit predictions\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m     outputs \u001b[39m=\u001b[39m model(b_input_ids, token_type_ids\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m     69\u001b[0m                     attention_mask\u001b[39m=\u001b[39;49mb_input_mask)\n\u001b[1;32m     71\u001b[0m logits \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m     72\u001b[0m \u001b[39m#print(logits)\u001b[39;00m\n\u001b[1;32m     73\u001b[0m \n\u001b[1;32m     74\u001b[0m \u001b[39m# Move logits and labels to CPU ???\u001b[39;00m\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1552\u001b[0m, in \u001b[0;36mBertForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1544\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1545\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m   1546\u001b[0m \u001b[39m    Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m   1547\u001b[0m \u001b[39m    config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m   1548\u001b[0m \u001b[39m    `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m   1549\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   1550\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1552\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbert(\n\u001b[1;32m   1553\u001b[0m     input_ids,\n\u001b[1;32m   1554\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m   1555\u001b[0m     token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m   1556\u001b[0m     position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m   1557\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m   1558\u001b[0m     inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m   1559\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m   1560\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m   1561\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m   1562\u001b[0m )\n\u001b[1;32m   1564\u001b[0m pooled_output \u001b[39m=\u001b[39m outputs[\u001b[39m1\u001b[39m]\n\u001b[1;32m   1566\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout(pooled_output)\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1014\u001b[0m, in \u001b[0;36mBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1005\u001b[0m head_mask \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_head_mask(head_mask, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m   1007\u001b[0m embedding_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(\n\u001b[1;32m   1008\u001b[0m     input_ids\u001b[39m=\u001b[39minput_ids,\n\u001b[1;32m   1009\u001b[0m     position_ids\u001b[39m=\u001b[39mposition_ids,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1012\u001b[0m     past_key_values_length\u001b[39m=\u001b[39mpast_key_values_length,\n\u001b[1;32m   1013\u001b[0m )\n\u001b[0;32m-> 1014\u001b[0m encoder_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mencoder(\n\u001b[1;32m   1015\u001b[0m     embedding_output,\n\u001b[1;32m   1016\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mextended_attention_mask,\n\u001b[1;32m   1017\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m   1018\u001b[0m     encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m   1019\u001b[0m     encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_extended_attention_mask,\n\u001b[1;32m   1020\u001b[0m     past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m   1021\u001b[0m     use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m   1022\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m   1023\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m   1024\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m   1025\u001b[0m )\n\u001b[1;32m   1026\u001b[0m sequence_output \u001b[39m=\u001b[39m encoder_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m   1027\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler(sequence_output) \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1186\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1187\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1188\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1189\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1191\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:580\u001b[0m, in \u001b[0;36mBertEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    577\u001b[0m layer_head_mask \u001b[39m=\u001b[39m head_mask[i] \u001b[39mif\u001b[39;00m head_mask \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m    578\u001b[0m past_key_value \u001b[39m=\u001b[39m past_key_values[i] \u001b[39mif\u001b[39;00m past_key_values \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> 580\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgradient_checkpointing \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining:\n\u001b[1;32m    582\u001b[0m     \u001b[39mif\u001b[39;00m use_cache:\n\u001b[1;32m    583\u001b[0m         logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m    584\u001b[0m             \u001b[39m\"\u001b[39m\u001b[39m`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    585\u001b[0m         )\n",
+            "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/geode-classification-py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1265\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1263\u001b[0m     \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m   1264\u001b[0m         \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1265\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m   1266\u001b[0m     \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
+            "\u001b[0;31mAttributeError\u001b[0m: 'BertEncoder' object has no attribute 'gradient_checkpointing'"
+          ]
+        }
+      ],
       "source": [
-        "len(p)"
+        "pred = predict(model, data_loader, device)"
       ]
     },
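+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "The `gradient_checkpointing` AttributeError above is the usual symptom of un-pickling a whole model object (`torch.load`) under a different transformers version than the one that saved it: the restored `BertEncoder` predates that attribute. Rebuilding the model with `from_pretrained`, as in the cell above, avoids the stale pickle.\n"
+      ]
+    },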
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "id": "rFFGhaCvQHfh"
-      },
+      "metadata": {},
       "outputs": [],
       "source": []
     },
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "qgJ-O4rcQHiI",
-        "outputId": "bfe93dd6-4d89-4d5c-be0d-45e1c98c6b14"
-      },
-      "outputs": [],
-      "source": [
-        "# Il faudrait enregistrer l'encoder, \n",
-        "# sinon on est obligé de le refaire à partir du jeu d'entrainement pour récupérer le noms des classes.\n",
-        "encoder"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "QuST9wJoQHnS"
-      },
-      "outputs": [],
-      "source": [
-        "p2 = list(encoder.inverse_transform(p))"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "6ek7suq9QHqE",
-        "outputId": "6636983a-7eba-48c8-d884-f8fb437294dc"
-      },
-      "outputs": [],
-      "source": [
-        "p2"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "XvdDj5PBQHtk"
-      },
+      "metadata": {},
       "outputs": [],
       "source": []
     },
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "id": "t39Xs0j7QHXJ"
-      },
+      "metadata": {},
       "outputs": [],
-      "source": [
-        "df_LGE['class_bert'] = p2"
-      ]
+      "source": []
     },
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 206
-        },
-        "id": "-VZ7geRmQHaD",
-        "outputId": "350a4122-5b1f-43e2-e372-2f628f665c4a"
-      },
+      "metadata": {},
       "outputs": [],
-      "source": [
-        "df_LGE.head()"
-      ]
+      "source": []
     },
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "id": "3xkzdkrKQHwA"
-      },
+      "metadata": {},
       "outputs": [],
-      "source": [
-        "df_LGE.to_csv(\"drive/MyDrive/Classification-EDdA/classification_LGE.tsv\", sep=\"\\t\")"
-      ]
+      "source": []
     }
   ],
   "metadata": {
-- 
GitLab