Skip to content
Snippets Groups Projects
Unverified Commit 37db3c96 authored by darrylong, committed by GitHub
Browse files

Enhance serving evaluation endpoints (#595)


* Include metric_user_results in evaluation response, added eval json endpoint

* Remove query from response

* Use the inverse user ID map to return original user IDs in the response

* Update serving test case to remove 'query' and add 'user_result' in response

* simplify user ID mapping

* Combined evaluation and evaluation_json endpoints

* Updated abort responses to show plaintext instead of html

* Added unit test cases

* Updated error response for empty data

* Added unit tests for provided data evaluation

* Update app.py

* Update test_app.py

---------

Co-authored-by: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
parent 92a94e38
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ from cornac.eval_methods import BaseMethod
from cornac.metrics import *
try:
from flask import Flask, jsonify, request
from flask import Flask, jsonify, request, abort, make_response
except ImportError:
exit("Flask is required in order to serve models.\n" + "Run: pip3 install Flask")
......@@ -185,7 +185,6 @@ def add_feedback():
return jsonify(data), 200
# curl -X POST -H "Content-Type: application/json" -d '{"metrics": ["RMSE()", "NDCG(k=10)"]}' "http://localhost:8080/evaluate"
@app.route("/evaluate", methods=["POST"])
def evaluate():
global model, train_set, metric_classnames
......@@ -197,20 +196,59 @@ def evaluate():
return "Unable to evaluate. 'train_set' is not provided", 400
query = request.json
validate_query(query)
query_metrics = query.get("metrics")
rating_threshold = query.get("rating_threshold", 1.0)
exclude_unknowns = (
query.get("exclude_unknowns", "true").lower() == "true"
) # exclude unknown users/items by default, otherwise specified
if "data" in query:
data = query.get("data")
else:
data = []
data_fpath = "data/feedback.csv"
if os.path.exists(data_fpath):
reader = Reader()
data = reader.read(data_fpath, fmt="UIR", sep=",")
if not data:
response = make_response("No feedback has been provided so far. No data available to evaluate the model.")
response.status_code = 400
abort(response)
test_set = Dataset.build(
data,
fmt="UIR",
global_uid_map=train_set.uid_map,
global_iid_map=train_set.iid_map,
exclude_unknowns=exclude_unknowns,
)
return process_evaluation(test_set, query, exclude_unknowns)
def validate_query(query):
    """Validate the parsed JSON payload of an evaluation request.

    Aborts the current request with a plain-text HTTP 400 response when the
    payload is not a JSON object, when ``metrics`` is missing or empty, or
    when ``metrics`` is not a list. Returns ``None`` when the payload is
    acceptable.

    Parameters
    ----------
    query : dict
        Parsed JSON body of the request.
    """
    if not isinstance(query, dict):
        # A JSON array/string/number has no ``.get``; reject it with a 400
        # instead of letting an AttributeError surface as a server error.
        error = "request body must be a JSON object"
    else:
        query_metrics = query.get("metrics")
        if not query_metrics:
            error = "metrics is required"
        elif not isinstance(query_metrics, list):
            error = "metrics must be an array of metrics"
        else:
            return

    # Build an explicit response so the client receives the bare message
    # rather than the default HTML error page.
    response = make_response(error)
    response.status_code = 400
    abort(response)
def process_evaluation(test_set, query, exclude_unknowns):
global model, train_set
rating_threshold = query.get("rating_threshold", 1.0)
user_based = (
query.get("user_based", "true").lower() == "true"
) # user_based evaluation by default, otherwise specified
if query_metrics is None:
return "metrics is required", 400
elif not isinstance(query_metrics, list):
return "metrics must be an array of metrics", 400
query_metrics = query.get("metrics")
# organize metrics
metrics = []
......@@ -226,24 +264,6 @@ def evaluate():
rating_metrics, ranking_metrics = BaseMethod.organize_metrics(metrics)
# read data
data = []
data_fpath = "data/feedback.csv"
if os.path.exists(data_fpath):
reader = Reader()
data = reader.read(data_fpath, fmt="UIR", sep=",")
if not len(data):
raise ValueError("No data available to evaluate the model.")
test_set = Dataset.build(
data,
fmt="UIR",
global_uid_map=train_set.uid_map,
global_iid_map=train_set.iid_map,
exclude_unknowns=exclude_unknowns,
)
# evaluation
result = BaseMethod.eval(
model=model,
......@@ -258,10 +278,17 @@ def evaluate():
verbose=False,
)
# map user index back into the original user ID
metric_user_results = {}
for metric, user_results in result.metric_user_results.items():
metric_user_results[metric] = {
train_set.user_ids[int(k)]: v for k, v in user_results.items()
}
# response
response = {
"result": result.metric_avg_results,
"query": query,
"user_result": metric_user_results,
}
return jsonify(response), 200
......
......@@ -96,9 +96,10 @@ def test_evaluate_json(client):
response = client.post('/evaluate', json=json_data)
# assert response.content_type == 'application/json'
assert response.status_code == 200
assert len(response.json['query']['metrics']) == 2
assert 'RMSE' in response.json['result']
assert 'Recall@5' in response.json['result']
assert 'RMSE' in response.json['user_result']
assert 'Recall@5' in response.json['user_result']
def test_evalulate_incorrect_get(client):
......@@ -110,3 +111,52 @@ def test_evalulate_incorrect_post(client):
response = client.post('/evaluate')
assert response.status_code == 415 # bad request, expect json
def test_evaluate_missing_metrics(client):
    """POST /evaluate with an empty ``metrics`` list must answer 400."""
    reply = client.post('/evaluate', json={'metrics': []})

    assert reply.status_code == 400
    assert reply.data == b'metrics is required'
def test_evaluate_not_list_metrics(client):
    """POST /evaluate with ``metrics`` given as a string must answer 400."""
    reply = client.post('/evaluate', json={'metrics': 'RMSE()'})

    assert reply.status_code == 400
    assert reply.data == b'metrics must be an array of metrics'
def test_recommend_missing_uid(client):
    """GET /recommend without a ``uid`` parameter must answer 400."""
    reply = client.get('/recommend?k=5')

    assert reply.status_code == 400
    assert reply.data == b'uid is required'
def test_evaluate_use_data(client):
    """POST /evaluate with inline ``data`` returns average and per-user results."""
    payload = {
        'metrics': ['RMSE()', 'Recall(k=5)'],
        'data': [['930', '795', 5], ['195', '795', 3]],
    }
    reply = client.post('/evaluate', json=payload)

    # assert response.content_type == 'application/json'
    assert reply.status_code == 200

    # Both metrics must be reported in the averaged and the per-user sections.
    for section in ('result', 'user_result'):
        for metric_name in ('RMSE', 'Recall@5'):
            assert metric_name in reply.json[section]
def test_evaluate_use_data_empty(client):
    """POST /evaluate with an empty inline ``data`` list must answer 400."""
    payload = {'metrics': ['RMSE()', 'Recall(k=5)'], 'data': []}
    reply = client.post('/evaluate', json=payload)

    expected_body = b"No feedback has been provided so far. No data available to evaluate the model."
    assert reply.status_code == 400
    assert reply.data == expected_body
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment