From dfede5ac655f023b40ed83dd5c29dfe24988db85 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <moncla.ludovic@gmail.com>
Date: Fri, 25 Nov 2022 14:38:58 +0100
Subject: [PATCH] Update Predict_LGE.ipynb

---
 notebooks/Predict_LGE.ipynb | 652 ++++++++++--------------------------
 1 file changed, 178 insertions(+), 474 deletions(-)

diff --git a/notebooks/Predict_LGE.ipynb b/notebooks/Predict_LGE.ipynb
index dc8acc9..b74e579 100644
--- a/notebooks/Predict_LGE.ipynb
+++ b/notebooks/Predict_LGE.ipynb
@@ -81,13 +81,41 @@
     "drive.mount('/content/drive')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "wSqbrupGMc1M"
+   },
+   "source": [
+    "### 1.2 Import libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "id": "SkErnwgMMbRj"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import torch\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, CamembertForSequenceClassification\n",
+    "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
     "id": "8hzEGHl7gmzk"
    },
    "source": [
-    "### 1.2 Setup GPU"
+    "### 1.3 Setup GPU"
    ]
   },
   {
@@ -110,46 +138,152 @@
     }
    ],
    "source": [
-    "import torch\n",
-    "\n",
-    "# If there's a GPU available...\n",
+    "# If there's a GPU available...\n",
     "if torch.cuda.is_available(): \n",
     "    # Tell PyTorch to use the GPU. \n",
     "    device = torch.device(\"cuda\")\n",
+    "    gpu_name = \"cuda\"\n",
     "    print('There are %d GPU(s) available.' % torch.cuda.device_count())\n",
     "    print('We will use the GPU:', torch.cuda.get_device_name(0))\n",
-    "\n",
     "# for MacOS\n",
     "elif torch.backends.mps.is_available() and torch.backends.mps.is_built():\n",
     "    device = torch.device(\"mps\")\n",
+    "    gpu_name = \"mps\"\n",
     "    print('We will use the GPU')\n",
     "else:\n",
     "    device = torch.device(\"cpu\")\n",
+    "    gpu_name = \"cpu\"\n",
     "    print('No GPU available, using the CPU instead.')"
    ]
   },
   {
    "cell_type": "markdown",
-   "metadata": {
-    "id": "wSqbrupGMc1M"
-   },
+   "metadata": {},
    "source": [
-    "### 1.3 Import librairies"
+    "## 2. Utils"
Utils" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "SkErnwgMMbRj" - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "import pandas as pd \n", - "import numpy as np\n", + "def generate_dataloader(tokenizer, sentences, batch_size = 8, max_len = 512):\n", "\n", - "from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, CamembertForSequenceClassification\n", - "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler" + " # Tokenize all of the sentences and map the tokens to thier word IDs.\n", + " input_ids_test = []\n", + " # For every sentence...\n", + " for sent in sentences:\n", + " # `encode` will:\n", + " # (1) Tokenize the sentence.\n", + " # (2) Prepend the `[CLS]` token to the start.\n", + " # (3) Append the `[SEP]` token to the end.\n", + " # (4) Map tokens to their IDs.\n", + " encoded_sent = tokenizer.encode(\n", + " sent, # Sentence to encode.\n", + " add_special_tokens = True, # Add '[CLS]' and '[SEP]'\n", + " # This function also supports truncation and conversion\n", + " # to pytorch tensors, but I need to do padding, so I\n", + " # can't use these features.\n", + " #max_length = max_len, # Truncate all sentences.\n", + " #return_tensors = 'pt', # Return pytorch tensors.\n", + " )\n", + " input_ids_test.append(encoded_sent)\n", + "\n", + " # Pad our input tokens\n", + " padded_test = []\n", + " for i in input_ids_test:\n", + " if len(i) > max_len:\n", + " padded_test.extend([i[:max_len]])\n", + " else:\n", + " padded_test.extend([i + [0] * (max_len - len(i))])\n", + " input_ids_test = np.array(padded_test)\n", + "\n", + " # Create attention masks\n", + " attention_masks = []\n", + "\n", + " # Create a mask of 1s for each token followed by 0s for padding\n", + " for seq in input_ids_test:\n", + " seq_mask = [float(i>0) for i in seq]\n", + " attention_masks.append(seq_mask)\n", + "\n", + " # Convert to tensors.\n", + " inputs = torch.tensor(input_ids_test)\n", + " masks = torch.tensor(attention_masks)\n", + " #set batch size\n", + "\n", + " # Create the DataLoader.\n", + " data = TensorDataset(inputs, masks)\n", + " prediction_sampler = SequentialSampler(data)\n", + "\n", + " return DataLoader(data, sampler=prediction_sampler, batch_size=batch_size)\n", + "\n", + "\n", + "def predict(model, dataloader, device):\n", + "\n", + " # Put model in evaluation mode\n", + " model.eval()\n", + "\n", + " # Tracking variables\n", + " predictions_test , true_labels = [], []\n", + " pred_labels_ = []\n", + " # Predict\n", + " for batch in dataloader:\n", + " # Add batch to GPU\n", + " batch = tuple(t.to(device) for t in batch)\n", + "\n", + " # Unpack the inputs from the dataloader\n", + " b_input_ids, b_input_mask = batch\n", + "\n", + " # Telling the model not to compute or store gradients, saving memory and\n", + " # speeding up prediction\n", + " with torch.no_grad():\n", + " # Forward pass, calculate logit predictions\n", + " outputs = model(b_input_ids, token_type_ids=None,\n", + " attention_mask=b_input_mask)\n", + " logits = outputs[0]\n", + " #print(logits)\n", + "\n", + " # Move logits and labels to CPU ???\n", + " logits = logits.detach().cpu().numpy()\n", + " #print(logits)\n", + "\n", + " # Store predictions and true labels\n", + " predictions_test.append(logits)\n", + "\n", + " pred_labels = []\n", + " \n", + " for i in range(len(predictions_test)):\n", + " # The predictions for this batch are a 2-column ndarray (one column for \"0\"\n", + " # and one column for 
\"1\"). Pick the label with the highest value and turn this\n", + " # in to a list of 0s and 1s.\n", + " pred_labels_i = np.argmax(predictions_test[i], axis=1).flatten()\n", + " pred_labels.append(pred_labels_i)\n", + "\n", + " pred_labels_ += [item for sublist in pred_labels for item in sublist]\n", + " return pred_labels_\n", + "\n", + "\n", + "def text_folder_to_dataframe(path):\n", + "\n", + " data = []\n", + " # id,tome,filename,nb_words,content,domain\n", + "\n", + " for tome in sorted(os.listdir(path)):\n", + " try:\n", + " for article in tqdm(sorted(os.listdir(path + \"/\" + tome))):\n", + " filename = article[:-4]\n", + " id = tome + filename\n", + "\n", + " if article[-4:] == \".txt\":\n", + " with open(path + \"/\" + tome + \"/\" + article) as f:\n", + " content = f.read()\n", + "\n", + " data.append([id, tome, filename, content, len(content.split(' '))])\n", + " except NotADirectoryError:\n", + " pass\n", + " return pd.DataFrame(data, columns=['id', 'tome', 'filename', 'content', 'nb_words'])\n" ] }, { @@ -158,7 +292,7 @@ "id": "c5QKcXulhNJ-" }, "source": [ - "## 2. Load Data" + "## 3. Load Data" ] }, { @@ -167,8 +301,16 @@ "metadata": {}, "outputs": [], "source": [ - "#path = \"drive/MyDrive/Classification-EDdA/\"\n", - "path = \"../\"" + "!wget https://api.nakala.fr/data/10.34847/nkl.74eb1xfd/e522413b58b04ab7c283f8fa68642e9cb69ab5c5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!unzip e522413b58b04ab7c283f8fa68642e9cb69ab5c5" ] }, { @@ -177,7 +319,8 @@ "metadata": {}, "outputs": [], "source": [ - "!wget https://projet.liris.cnrs.fr/geode/files/datasets/EDdA/Classification/LGE_withContent.tsv" + "#input_path = \"/Users/lmoncla/Documents/Data/Corpus/LGE/Text\"\n", + "input_path = \"./Text\"" ] }, { @@ -186,7 +329,8 @@ "metadata": {}, "outputs": [], "source": [ - "df_LGE = pd.read_csv(path + \"data/LGE_withContent.tsv\", sep=\"\\t\")\n", + "df_LGE = text_folder_to_dataframe(input_path)\n", + "#df_LGE = pd.read_csv(path + \"data/LGE_withContent.tsv\", sep=\"\\t\")\n", "data_LGE = df_LGE[\"content\"].values" ] }, @@ -334,115 +478,13 @@ "metadata": {}, "outputs": [], "source": [ + "#path = \"drive/MyDrive/Classification-EDdA/\"\n", + "path = \"../\"\n", "model_name = \"bert-base-multilingual-cased\"\n", "#model_name = \"camembert-base\"\n", "model_path = path + \"models/model_\" + model_name + \"_s10000.pt\"" ] }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_dataloader(tokenizer, sentences, batch_size = 8, max_len = 512):\n", - "\n", - " # Tokenize all of the sentences and map the tokens to thier word IDs.\n", - " input_ids_test = []\n", - " # For every sentence...\n", - " for sent in sentences:\n", - " # `encode` will:\n", - " # (1) Tokenize the sentence.\n", - " # (2) Prepend the `[CLS]` token to the start.\n", - " # (3) Append the `[SEP]` token to the end.\n", - " # (4) Map tokens to their IDs.\n", - " encoded_sent = tokenizer.encode(\n", - " sent, # Sentence to encode.\n", - " add_special_tokens = True, # Add '[CLS]' and '[SEP]'\n", - " # This function also supports truncation and conversion\n", - " # to pytorch tensors, but I need to do padding, so I\n", - " # can't use these features.\n", - " #max_length = max_len, # Truncate all sentences.\n", - " #return_tensors = 'pt', # Return pytorch tensors.\n", - " )\n", - " input_ids_test.append(encoded_sent)\n", - "\n", - " # Pad our input tokens\n", - " padded_test = []\n", - " for i in 
input_ids_test:\n", - " if len(i) > max_len:\n", - " padded_test.extend([i[:max_len]])\n", - " else:\n", - " padded_test.extend([i + [0] * (max_len - len(i))])\n", - " input_ids_test = np.array(padded_test)\n", - "\n", - " # Create attention masks\n", - " attention_masks = []\n", - "\n", - " # Create a mask of 1s for each token followed by 0s for padding\n", - " for seq in input_ids_test:\n", - " seq_mask = [float(i>0) for i in seq]\n", - " attention_masks.append(seq_mask)\n", - "\n", - " # Convert to tensors.\n", - " inputs = torch.tensor(input_ids_test)\n", - " masks = torch.tensor(attention_masks)\n", - " #set batch size\n", - "\n", - " # Create the DataLoader.\n", - " data = TensorDataset(inputs, masks)\n", - " prediction_sampler = SequentialSampler(data)\n", - "\n", - " return DataLoader(data, sampler=prediction_sampler, batch_size=batch_size)\n", - "\n", - "\n", - "\n", - "def predict(model, dataloader, device):\n", - "\n", - " # Put model in evaluation mode\n", - " model.eval()\n", - "\n", - " # Tracking variables\n", - " predictions_test , true_labels = [], []\n", - " pred_labels_ = []\n", - " # Predict\n", - " for batch in dataloader:\n", - " # Add batch to GPU\n", - " batch = tuple(t.to(device) for t in batch)\n", - "\n", - " # Unpack the inputs from the dataloader\n", - " b_input_ids, b_input_mask = batch\n", - "\n", - " # Telling the model not to compute or store gradients, saving memory and\n", - " # speeding up prediction\n", - " with torch.no_grad():\n", - " # Forward pass, calculate logit predictions\n", - " outputs = model(b_input_ids, token_type_ids=None,\n", - " attention_mask=b_input_mask)\n", - "\n", - " logits = outputs[0]\n", - " #print(logits)\n", - "\n", - " # Move logits and labels to CPU ???\n", - " logits = logits.detach().cpu().numpy()\n", - " #print(logits)\n", - "\n", - " # Store predictions and true labels\n", - " predictions_test.append(logits)\n", - "\n", - " pred_labels = []\n", - " \n", - " for i in range(len(predictions_test)):\n", - " # The predictions for this batch are a 2-column ndarray (one column for \"0\"\n", - " # and one column for \"1\"). 
-    "        # and one column for \"1\"). Pick the label with the highest value and turn this\n",
-    "        # in to a list of 0s and 1s.\n",
-    "        pred_labels_i = np.argmax(predictions_test[i], axis=1).flatten()\n",
-    "        pred_labels.append(pred_labels_i)\n",
-    "\n",
-    "    pred_labels_ += [item for sublist in pred_labels for item in sublist]\n",
-    "    return pred_labels_"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 16,
@@ -482,26 +524,13 @@
     "data_loader = generate_dataloader(tokenizer, data_LGE)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "\n",
-    "https://discuss.huggingface.co/t/an-efficient-way-of-loading-a-model-that-was-saved-with-torch-save/9814\n",
-    "\n",
-    "https://github.com/huggingface/transformers/issues/2094\n"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
-    "#model = torch.load(model_path, map_location=torch.device('mps'))\n",
-    "#model.load_state_dict(torch.load(model_path, map_location=torch.device('mps')))\n",
-    "\n",
-    "model = BertForSequenceClassification.from_pretrained(model_path).to(\"mps\") #.to(\"cuda\")"
+    "model = BertForSequenceClassification.from_pretrained(model_path).to(gpu_name)"
    ]
   },
@@ -519,344 +548,14 @@
     "pred = predict(model, data_loader, device)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[15,\n",
-       " 6,\n",
-       " 16,\n",
-       " 15,\n",
-       " 17,\n",
-       " 10,\n",
-       " 17,\n",
-       " 16,\n",
-       " 19,\n",
-       " 35,\n",
-       " 15,\n",
-       " 26,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 2,\n",
-       " 2,\n",
-       " 17,\n",
-       " 6,\n",
-       " 32,\n",
-       " 17,\n",
-       " 30,\n",
-       " 16,\n",
-       " 32,\n",
-       " 15,\n",
-       " 35,\n",
-       " 15,\n",
-       " 23,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 17,\n",
-       " 15,\n",
-       " 16,\n",
-       " 3,\n",
-       " 17,\n",
-       " 17,\n",
-       " 16,\n",
-       " 4,\n",
-       " 15,\n",
-       " 17,\n",
-       " 19,\n",
-       " 16,\n",
-       " 35,\n",
-       " 3,\n",
-       " 17,\n",
-       " 5,\n",
-       " 15,\n",
-       " 16,\n",
-       " 16,\n",
-       " 15,\n",
-       " 16,\n",
-       " 6,\n",
-       " 16,\n",
-       " 5,\n",
-       " 16,\n",
-       " 15,\n",
-       " 28,\n",
-       " 16,\n",
-       " 17,\n",
-       " 10,\n",
-       " 15,\n",
-       " 15,\n",
-       " 32,\n",
-       " 15,\n",
-       " 17,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 12,\n",
-       " 15,\n",
-       " 18,\n",
-       " 15,\n",
-       " 35,\n",
-       " 26,\n",
-       " 16,\n",
-       " 16,\n",
-       " 15,\n",
-       " 5,\n",
-       " 15,\n",
-       " 15,\n",
-       " 5,\n",
-       " 17,\n",
-       " 15,\n",
-       " 17,\n",
-       " 35,\n",
-       " 15,\n",
-       " 16,\n",
-       " 16,\n",
-       " 17,\n",
-       " 2,\n",
-       " 17,\n",
-       " 15,\n",
-       " 16,\n",
-       " 23,\n",
-       " 16,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 16,\n",
-       " 6,\n",
-       " 15,\n",
-       " 35,\n",
-       " 15,\n",
-       " 32,\n",
-       " 16,\n",
-       " 6,\n",
-       " 16,\n",
-       " 23,\n",
-       " 36,\n",
-       " 5,\n",
-       " 35,\n",
-       " 3,\n",
-       " 3,\n",
-       " 3,\n",
-       " 16,\n",
-       " 17,\n",
-       " 2,\n",
-       " 15,\n",
-       " 5,\n",
-       " 17,\n",
-       " 16,\n",
-       " 15,\n",
-       " 17,\n",
-       " 6,\n",
-       " 15,\n",
-       " 16,\n",
-       " 10,\n",
-       " 16,\n",
-       " 15,\n",
-       " 35,\n",
-       " 17,\n",
-       " 15,\n",
-       " 15,\n",
-       " 6,\n",
-       " 28,\n",
-       " 16,\n",
-       " 15,\n",
-       " 15,\n",
-       " 15,\n",
-       " 16,\n",
-       " 5,\n",
-       " 15,\n",
-       " 21,\n",
-       " 5,\n",
-       " 1,\n",
-       " 7,\n",
-       " 16,\n",
-       " 15,\n",
-       " 17,\n",
-       " 23,\n",
-       " 15,\n",
-       " 5,\n",
-       " 0,\n",
-       " 10,\n",
-       " 16,\n",
-       " 16,\n",
-       " 15,\n",
-       " 16,\n",
-       " 15,\n",
-       " 15,\n",
-       " 3,\n",
-       " 3,\n",
-       " 17,\n",
-       " 36,\n",
-       " 16,\n",
-       " 15,\n",
-       " 12,\n",
-       " 6,\n",
-       " 15,\n",
-       " 4,\n",
-       " 16,\n",
-       " 16,\n",
-       " 26,\n",
-       " 15,\n",
-       " 15,\n",
-       " 32,\n",
-       " 15,\n",
-       " 10,\n",
-       " 15,\n",
-       " 5,\n",
-       " 26,\n",
-       " 5,\n",
15,\n", - " 15,\n", - " 26,\n", - " 15,\n", - " 35,\n", - " 15,\n", - " 16,\n", - " 16,\n", - " 15,\n", - " 6,\n", - " 16,\n", - " 12,\n", - " 16,\n", - " 28,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 6,\n", - " 10,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 5,\n", - " 16,\n", - " 16,\n", - " 17,\n", - " 15,\n", - " 16,\n", - " 35,\n", - " 16,\n", - " 16,\n", - " 15,\n", - " 6,\n", - " 29,\n", - " 16,\n", - " 15,\n", - " 5,\n", - " 5,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 31,\n", - " 16,\n", - " 15,\n", - " 16,\n", - " 15,\n", - " 6,\n", - " 16,\n", - " 3,\n", - " 15,\n", - " 2,\n", - " 15,\n", - " 15,\n", - " 28,\n", - " 17,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 10,\n", - " 15,\n", - " 5,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 17,\n", - " 15,\n", - " 5,\n", - " 15,\n", - " 3,\n", - " 15,\n", - " 2,\n", - " 15,\n", - " 15,\n", - " 6,\n", - " 15,\n", - " 28,\n", - " 15,\n", - " 6,\n", - " 15,\n", - " 32,\n", - " 16,\n", - " 15,\n", - " 2,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 17,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 16,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 35,\n", - " 15,\n", - " 15,\n", - " 35,\n", - " 16,\n", - " 28,\n", - " 15,\n", - " 15,\n", - " 15,\n", - " 5,\n", - " 15,\n", - " 15,\n", - " 19,\n", - " 15]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pred" - ] - }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "import pickle \n", "encoder_filename = \"models/label_encoder.pkl\"\n", - "with open(path+encoder_filename, 'rb') as file:\n", + "with open(path + encoder_filename, 'rb') as file:\n", " encoder = pickle.load(file)" ] }, @@ -875,7 +574,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_LGE['class_bert'] = p2" + "df_LGE['domain'] = p2" ] }, { @@ -1591,7 +1290,8 @@ "metadata": {}, "outputs": [], "source": [ - "df_LGE.to_csv(path + \"reports/classification_LGE.tsv\", sep=\"\\t\")" + "filepath = path + \"results_LGE/LGE-metadata-withContent.csv\"\n", + "df_LGE.to_csv(filepath, sep=\"\\,\")" ] }, { @@ -1599,7 +1299,11 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "df_LGE.drop(columns=['content'], inplace=True)\n", + "filepath = path + \"results_LGE/LGE-metadata.csv\"\n", + "df_LGE.to_csv(filepath, sep=\"\\,\")" + ] } ], "metadata": { -- GitLab