From eaabc9e72b649b2209d6f9ec0dc319b972202c0b Mon Sep 17 00:00:00 2001 From: Arthur BATEL <arthur.batel@insa-lyon.fr> Date: Sat, 20 Apr 2024 16:21:49 +0200 Subject: [PATCH] new metrics --- code/binary_bpr/BPR_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/code/binary_bpr/BPR_model.py b/code/binary_bpr/BPR_model.py index 63431e4..3e202ea 100644 --- a/code/binary_bpr/BPR_model.py +++ b/code/binary_bpr/BPR_model.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, TensorDataset import numpy as np import torch.nn.functional as F import pandas as pd -from sklearn.metrics import roc_auc_score +from sklearn.metrics import roc_auc_score, recall_score, f1_score from sklearn.metrics import mean_squared_error from sklearn.metrics import accuracy_score from sklearn.metrics import precision_score @@ -130,7 +130,10 @@ class BPRModel(nn.Module): if epoch - best_ite >= quit_delta: break elif epoch % 5 == 0 : - print("[Epoch %d] loss: %.6f" % (epoch, loss.item())) + try : + print("[Epoch %d] loss: %.6f" % (epoch, loss.item())) + except UnboundLocalError as e: + print(e) return acc # Evaluate the model @@ -184,4 +187,4 @@ class BPRModel(nn.Module): mse1 = mean_squared_error(all_labels, all_predictions) # Compute AUC for the entire dataset auc = roc_auc_score(all_labels, all_predictions) - return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) + return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) ,precision_score(all_labels, all_decisions),recall_score(all_labels, all_decisions),f1_score(all_labels, all_decisions) -- GitLab