diff --git a/main.py b/main.py index a4074a243a26c9b71bdecac735ea861dd5ea86b6..a9aa6e448acb700621d9a96df4b7b8b7615a8491 100644 --- a/main.py +++ b/main.py @@ -236,9 +236,9 @@ def make_prediction_duo(model, data, f_name): # Build confusion matrix cf_matrix = confusion_matrix(y_true, y_pred) - df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], + df_cm = pd.DataFrame(cf_matrix[:, None], index=[i for i in classes], columns=[i for i in classes]) - plt.figure(figsize=(12, 7)) + plt.figure(figsize=(14, 9)) sn.heatmap(df_cm, annot=True) plt.savefig(f_name)