Skip to content
Snippets Groups Projects
BertFineTuning_.ipynb 48.7 KiB
Newer Older
        "### Report & Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d5n84N0xCfcU"
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import plot_confusion_matrix\n",
        "from sklearn.metrics import confusion_matrix\n",
        "from sklearn.metrics import classification_report\n",
        "import seaborn as sns"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v4hXk-KjC-nq"
      },
      "source": [
        "# Nested dict of per-class metrics plus the 'accuracy' and 'weighted avg' summaries.\n",
        "# sklearn's signature is classification_report(y_true, y_pred, ...): passing the\n",
        "# predictions first silently swaps every per-class precision/recall pair and\n",
        "# weights 'weighted avg' by predicted instead of true class counts (support).\n",
        "report = classification_report(true_labels_, pred_labels_, output_dict=True)\n",
        "\n",
        "accuracy = report['accuracy']\n",
        "weighted_avg = report['weighted avg']"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xETMy1L6DAa5"
      },
      "source": [
        "# Class names in the encoder's order, and the matching encoded ids as strings;\n",
        "# the string ids are the per-class keys used to index `report`.\n",
        "# NOTE(review): `encoder` is presumably a fitted sklearn LabelEncoder; confirm.\n",
        "classesName = encoder.classes_\n",
        "classes = list(map(str, encoder.transform(classesName)))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dPjV_5g8DDQy"
      },
      "source": [
        "# Per-class summary table: classification_report metrics plus one-vs-rest\n",
        "# confusion counts, one row per class known to the encoder.\n",
        "# NOTE(review): assumes true_labels_/pred_labels_ hold the encoder's integer ids\n",
        "# (the same values `classes` was built from); confirm.\n",
        "class_ids = encoder.transform(encoder.classes_)\n",
        "\n",
        "precision = [report[c]['precision'] for c in classes]\n",
        "recall = [report[c]['recall'] for c in classes]\n",
        "f1 = [report[c]['f1-score'] for c in classes]\n",
        "support = [report[c]['support'] for c in classes]\n",
        "\n",
        "# Pin the label order so confusion-matrix rows/columns line up with\n",
        "# `classesName`, instead of relying on the sorted set of labels that happen\n",
        "# to occur in the data. Rows are true labels, columns are predictions.\n",
        "cnf_matrix = confusion_matrix(true_labels_, pred_labels_, labels=class_ids)\n",
        "TP = np.diag(cnf_matrix)\n",
        "FP = cnf_matrix.sum(axis=0) - TP  # predicted as the class, actually another\n",
        "FN = cnf_matrix.sum(axis=1) - TP  # actually the class, predicted as another\n",
        "TN = cnf_matrix.sum() - (FP + FN + TP)\n",
        "\n",
        "# `accuracy` and `weighted_avg` are already set in the cell that builds `report`.\n",
        "dff = pd.DataFrame({\n",
        "    'className': classesName,\n",
        "    'precision': precision,\n",
        "    'recall': recall,\n",
        "    'f1-score': f1,\n",
        "    'support': support,\n",
        "    'FP': FP,\n",
        "    'FN': FN,\n",
        "    'TP': TP,\n",
        "    'TN': TN,\n",
        "})"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vslzi9bHDKcv"
      },
      "source": [
        "# Headline numbers first, then the per-class table. display() and the trailing\n",
        "# expression use Jupyter's rich (HTML) rendering instead of plain-text print().\n",
        "print('accuracy:', accuracy)\n",
        "display(pd.Series(weighted_avg, name='weighted avg'))\n",
        "dff"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}