From eaabc9e72b649b2209d6f9ec0dc319b972202c0b Mon Sep 17 00:00:00 2001
From: Arthur BATEL <arthur.batel@insa-lyon.fr>
Date: Sat, 20 Apr 2024 16:21:49 +0200
Subject: [PATCH] new metrics

---
 code/binary_bpr/BPR_model.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/code/binary_bpr/BPR_model.py b/code/binary_bpr/BPR_model.py
index 63431e4..3e202ea 100644
--- a/code/binary_bpr/BPR_model.py
+++ b/code/binary_bpr/BPR_model.py
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, TensorDataset
 import numpy as np
 import torch.nn.functional as F
 import pandas as pd
-from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_auc_score, recall_score, f1_score
 from sklearn.metrics import mean_squared_error
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import precision_score
@@ -130,7 +130,10 @@ class BPRModel(nn.Module):
                 if epoch - best_ite >= quit_delta:
                     break
             elif epoch % 5 == 0 :
-                print("[Epoch %d] loss: %.6f" % (epoch, loss.item()))
+                try :
+                    print("[Epoch %d] loss: %.6f" % (epoch, loss.item()))
+                except UnboundLocalError as e:
+                    print(e)
         return acc
        
     # Evaluate the model
@@ -184,4 +187,4 @@ class BPRModel(nn.Module):
         mse1 = mean_squared_error(all_labels, all_predictions)
         # Compute AUC for the entire dataset
         auc = roc_auc_score(all_labels, all_predictions)
-        return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) 
+        return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) ,precision_score(all_labels, all_decisions),recall_score(all_labels, all_decisions),f1_score(all_labels, all_decisions)
-- 
GitLab