diff --git a/main.py b/main.py index a9aa6e448acb700621d9a96df4b7b8b7615a8491..d2227c60efc495671e168a4da0406ac5e77b2df6 100644 --- a/main.py +++ b/main.py @@ -215,7 +215,7 @@ def run_duo(args): def make_prediction_duo(model, data, f_name): y_pred = [] y_true = [] - + print('Building confusion matrix') # iterate over test data for imaer,imana, label in data: label = label.long() @@ -238,6 +238,8 @@ def make_prediction_duo(model, data, f_name): cf_matrix = confusion_matrix(y_true, y_pred) df_cm = pd.DataFrame(cf_matrix[:, None], index=[i for i in classes], columns=[i for i in classes]) + + print('Saving Confusion Matrix') plt.figure(figsize=(14, 9)) sn.heatmap(df_cm, annot=True) plt.savefig(f_name)