Skip to content
Snippets Groups Projects
Commit eaabc9e7 authored by Arthur Batel's avatar Arthur Batel
Browse files

new metrics

parent 3c1f027a
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, TensorDataset ...@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, TensorDataset
import numpy as np import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
import pandas as pd import pandas as pd
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score, recall_score, f1_score
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score from sklearn.metrics import precision_score
...@@ -130,7 +130,10 @@ class BPRModel(nn.Module): ...@@ -130,7 +130,10 @@ class BPRModel(nn.Module):
if epoch - best_ite >= quit_delta: if epoch - best_ite >= quit_delta:
break break
elif epoch % 5 == 0 : elif epoch % 5 == 0 :
print("[Epoch %d] loss: %.6f" % (epoch, loss.item())) try :
print("[Epoch %d] loss: %.6f" % (epoch, loss.item()))
except UnboundLocalError as e:
print(e)
return acc return acc
# Evaluate the model # Evaluate the model
...@@ -184,4 +187,4 @@ class BPRModel(nn.Module): ...@@ -184,4 +187,4 @@ class BPRModel(nn.Module):
mse1 = mean_squared_error(all_labels, all_predictions) mse1 = mean_squared_error(all_labels, all_predictions)
# Compute AUC for the entire dataset # Compute AUC for the entire dataset
auc = roc_auc_score(all_labels, all_predictions) auc = roc_auc_score(all_labels, all_predictions)
return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) return (all_labels==all_decisions).astype(int),accuracy_score(all_labels, all_decisions) , users, auc, np.sqrt(mse1) ,precision_score(all_labels, all_decisions),recall_score(all_labels, all_decisions),f1_score(all_labels, all_decisions)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment