Unverified commit 6528157c authored by Quoc-Tuan Truong, committed by GitHub

Improve efficiency of KNN methods (#331)

parent ad2d1ff5
@@ -23,57 +23,70 @@ from ..recommender import Recommender
from ...exception import ScoreException
from ...utils import get_rng
from ...utils.fast_sparse_funcs import inplace_csr_row_normalize_l2
from .similarity import compute_similarity
from .similarity import compute_similarity, compute_score, compute_score_single
EPS = 1e-8
SIMILARITIES = ["cosine", "pearson"]
WEIGHTING_OPTIONS = ["idf", "bm25"]
def _mean_centered(csr_mat):
def _mean_centered(ui_mat):
"""Subtract every rating values with mean value of the corresponding rows"""
mean_arr = np.zeros(csr_mat.shape[0])
for i in range(csr_mat.shape[0]):
start_idx, end_idx = csr_mat.indptr[i : i + 2]
mean_arr[i] = np.mean(csr_mat.data[start_idx:end_idx])
csr_mat.data[start_idx:end_idx] -= mean_arr[i]
mean_arr = np.zeros(ui_mat.shape[0])
for i in range(ui_mat.shape[0]):
start_idx, end_idx = ui_mat.indptr[i : i + 2]
mean_arr[i] = np.mean(ui_mat.data[start_idx:end_idx])
row_data = ui_mat.data[start_idx:end_idx]
row_data -= mean_arr[i]
row_data[row_data == 0] = EPS
ui_mat.data[start_idx:end_idx] = row_data
return ui_mat, mean_arr
def _amplify(ui_mat, alpha=1.0):
"""Exponentially amplify values of similarity matrix"""
if alpha == 1.0:
return ui_mat
return csr_mat, mean_arr
for i, w in enumerate(ui_mat.data):
ui_mat.data[i] = w ** alpha if w > 0 else -(-w) ** alpha
return ui_mat
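
For orientation, here is a standalone NumPy/SciPy sketch of what the patched _mean_centered does (the toy matrix and helper name are illustrative, not from the patch). The EPS nudge matters because a centered rating that lands exactly on zero would become indistinguishable from "unrated" in CSR storage:

import numpy as np
from scipy.sparse import csr_matrix

EPS = 1e-8  # same role as the module-level constant added by this commit

def mean_center_rows(ui_mat):
    # subtract each row's mean from its stored (non-zero) entries
    mean_arr = np.zeros(ui_mat.shape[0])
    for i in range(ui_mat.shape[0]):
        start, end = ui_mat.indptr[i], ui_mat.indptr[i + 1]
        mean_arr[i] = ui_mat.data[start:end].mean()
        row_data = ui_mat.data[start:end] - mean_arr[i]
        row_data[row_data == 0] = EPS  # keep "rating == row mean" entries alive
        ui_mat.data[start:end] = row_data
    return ui_mat, mean_arr

R = csr_matrix(np.array([[5.0, 3.0, 0.0], [4.0, 0.0, 2.0]]))
centered, means = mean_center_rows(R)
print(means)          # [4. 3.]
print(centered.data)  # [ 1. -1.  1. -1.]
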
def _tfidf_weight(csr_mat):
"""Weight the matrix with TF-IDF"""
def _idf_weight(ui_mat):
"""Weight the matrix Inverse Document (Item) Frequency"""
X = coo_matrix(ui_mat)
# calculate IDF
N = float(csr_mat.shape[1])
idf = np.log(N) - np.log1p(np.bincount(csr_mat.indices))
N = float(X.shape[0])
idf = np.log(N / np.bincount(X.col))
# apply TF-IDF adjustment
csr_mat.data *= np.sqrt(idf[csr_mat.indices])
return csr_mat
weights = idf[ui_mat.indices] + EPS
return weights
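
Restated in plain SciPy (toy data, function name assumed): N is the number of users ("documents"), np.bincount(X.col) is each item's document frequency, and fit() multiplies the rating data by the square root of the returned per-entry weights. Note that an item rated by every user gets idf = 0, so EPS only prevents an exact zero weight:

import numpy as np
from scipy.sparse import coo_matrix, csr_matrix

EPS = 1e-8

def idf_weight(ui_mat):
    X = coo_matrix(ui_mat)
    N = float(X.shape[0])                 # number of users
    idf = np.log(N / np.bincount(X.col))  # per-item inverse frequency
    return idf[ui_mat.indices] + EPS      # one weight per stored CSR entry

R = csr_matrix(np.array([[5.0, 3.0], [4.0, 0.0], [1.0, 2.0]]))
R.data *= np.sqrt(idf_weight(R))  # how fit() applies the weights
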
def _bm25_weight(csr_mat):
def _bm25_weight(ui_mat):
"""Weight the matrix with BM25 algorithm"""
K1 = 1.2
B = 0.8
# calculate IDF
N = float(csr_mat.shape[1])
idf = np.log(N) - np.log1p(np.bincount(csr_mat.indices))
X = coo_matrix(ui_mat)
X.data = np.ones_like(X.data)
# calculate length_norm per document
row_sums = np.ravel(csr_mat.sum(axis=1))
N = float(X.shape[0])
idf = np.log(N / np.bincount(X.col))
# calculate length_norm per document (user)
row_sums = np.ravel(X.sum(axis=1))
average_length = row_sums.mean()
length_norm = (1.0 - B) + B * row_sums / average_length
# weight matrix rows by BM25
row_counts = np.ediff1d(csr_mat.indptr)
row_inds = np.repeat(np.arange(csr_mat.shape[0]), row_counts)
weights = (
(K1 + 1.0) / (K1 * length_norm[row_inds] + csr_mat.data) * idf[csr_mat.indices]
)
csr_mat.data *= np.sqrt(weights)
return csr_mat
# bm25 weights
weights = (K1 + 1.0) / (K1 * length_norm[X.row] + X.data) * idf[X.col] + EPS
return weights
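
The BM25 branch is the classic formula with the term frequency fixed at 1 (X.data is binarized), so only profile length matters: weight(u, i) = idf_i * (K1 + 1) / (K1 * ((1 - B) + B * |u| / avg|u|) + 1). Heavy users get length_norm > 1 and hence smaller weights, which damps prolific raters. A consolidation of the diff above into one runnable function (the toy demo data is made up):

import numpy as np
from scipy.sparse import coo_matrix

EPS = 1e-8

def bm25_weight(ui_mat, K1=1.2, B=0.8):
    X = coo_matrix(ui_mat)
    X.data = np.ones_like(X.data)       # binarize: tf = 1 per (user, item) pair
    N = float(X.shape[0])
    idf = np.log(N / np.bincount(X.col))
    row_sums = np.ravel(X.sum(axis=1))  # profile length per user
    length_norm = (1.0 - B) + B * row_sums / row_sums.mean()
    return (K1 + 1.0) / (K1 * length_norm[X.row] + X.data) * idf[X.col] + EPS

R = coo_matrix(np.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]))
print(bm25_weight(R))  # per-entry weights, aligned with X.row / X.col
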
class UserKNN(Recommender):
@@ -90,8 +103,12 @@ class UserKNN(Recommender):
similarity: str, optional, default: 'cosine'
The similarity measurement. Supported types: ['cosine', 'pearson']
mean_centered: bool, optional, default: False
Whether values of the user-item rating matrix will be centered by the mean
of their corresponding rows (mean rating of each user).
weighting: str, optional, default: None
The option for re-weighting the rating matrix. Supported types: [tf-idf', 'bm25'].
The option for re-weighting the rating matrix. Supported types: ['idf', 'bm25'].
If None, no weighting is applied.
amplify: float, optional, default: 1.0
@@ -110,14 +127,12 @@ class UserKNN(Recommender):
* Aggarwal, C. C. (2016). Recommender systems (Vol. 1). Cham: Springer International Publishing.
"""
SIMILARITIES = ["cosine", "pearson"]
WEIGHTING_OPTIONS = ["tf-idf", "bm25"]
def __init__(
self,
name="UserKNN",
k=20,
similarity="cosine",
mean_centered=False,
weighting=None,
amplify=1.0,
num_threads=0,
@@ -128,19 +143,20 @@
super().__init__(name=name, trainable=trainable, verbose=verbose)
self.k = k
self.similarity = similarity
self.mean_centered = mean_centered
self.weighting = weighting
self.amplify = amplify
self.seed = seed
self.rng = get_rng(seed)
if self.similarity not in self.SIMILARITIES:
if self.similarity not in SIMILARITIES:
raise ValueError(
"Invalid similarity choice, supported {}".format(self.SIMILARITIES)
"Invalid similarity choice, supported {}".format(SIMILARITIES)
)
if self.weighting is not None and self.weighting not in self.WEIGHTING_OPTIONS:
if self.weighting is not None and self.weighting not in WEIGHTING_OPTIONS:
raise ValueError(
"Invalid weighting choice, supported {}".format(self.WEIGHTING_OPTIONS)
"Invalid weighting choice, supported {}".format(WEIGHTING_OPTIONS)
)
if seed is not None:
@@ -169,25 +185,28 @@ class UserKNN(Recommender):
self.ui_mat = self.train_set.matrix.copy()
self.mean_arr = np.zeros(self.ui_mat.shape[0])
if self.train_set.min_rating != self.train_set.max_rating: # explicit feedback
self.ui_mat, self.mean_arr = _mean_centered(self.ui_mat)
if self.similarity == "cosine":
weight_mat = self.train_set.matrix.copy()
elif self.similarity == "pearson":
if self.mean_centered or self.similarity == "pearson":
weight_mat = self.ui_mat.copy()
else:
weight_mat = self.train_set.matrix.copy()
# rating matrix re-weighting
if self.weighting == "tf-idf":
weight_mat = _tfidf_weight(weight_mat)
# re-weighting
if self.weighting == "idf":
weight_mat.data *= np.sqrt(_idf_weight(self.train_set.matrix))
elif self.weighting == "bm25":
weight_mat = _bm25_weight(weight_mat)
weight_mat.data *= np.sqrt(_bm25_weight(self.train_set.matrix))
# only need item-user matrix for prediction
self.iu_mat = self.ui_mat.T.tocsr()
del self.ui_mat
inplace_csr_row_normalize_l2(weight_mat)
self.sim_mat = compute_similarity(
weight_mat, k=self.k, num_threads=self.num_threads, verbose=self.verbose
).power(self.amplify)
)
self.sim_mat = _amplify(self.sim_mat, self.amplify)
return self
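
The tail end of the new fit() in dense NumPy terms: after optional centering and re-weighting, rows are L2-normalized so a sparse dot product yields cosine similarity, and _amplify then raises each similarity to the alpha power while preserving its sign. A single-threaded sketch (helper name and toy data are assumptions); like the rewritten kernel, it keeps all pairs rather than pruning to top-k, since selection now happens at scoring time:

import numpy as np
from scipy.sparse import csr_matrix

def cosine_user_similarity(ui_mat, alpha=1.0):
    W = csr_matrix(ui_mat, dtype=np.float64).copy()
    # L2-normalize each user's row in place
    norms = np.sqrt(np.asarray(W.multiply(W).sum(axis=1)).ravel())
    norms[norms == 0] = 1.0
    W.data /= np.repeat(norms, np.diff(W.indptr))
    sim = (W @ W.T).toarray()
    # _amplify: signed power, so negative similarities keep their sign
    return np.sign(sim) * np.abs(sim) ** alpha

R = csr_matrix(np.array([[5.0, 3.0, 0.0], [4.0, 0.0, 2.0], [5.0, 3.0, 1.0]]))
print(cosine_user_similarity(R, alpha=2.0).round(3))
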
@@ -218,16 +237,30 @@
"Can't make score prediction for (item_id=%d)" % item_idx
)
user_weights = self.sim_mat[user_idx]
user_weights = user_weights / (
np.abs(user_weights).sum() + EPS
) # normalize for rating prediction
known_item_scores = (
self.mean_arr[user_idx] + user_weights.dot(self.ui_mat).A.ravel()
)
if item_idx is not None:
return known_item_scores[item_idx]
weighted_avg = compute_score_single(
True,
self.sim_mat[user_idx].A.ravel(),
self.iu_mat.indptr[item_idx],
self.iu_mat.indptr[item_idx + 1],
self.iu_mat.indices,
self.iu_mat.data,
k=self.k,
)
return self.mean_arr[user_idx] + weighted_avg
weighted_avg = np.zeros(self.train_set.num_items)
compute_score(
True,
self.sim_mat[user_idx].A.ravel(),
self.iu_mat.indptr,
self.iu_mat.indices,
self.iu_mat.data,
k=self.k,
num_threads=self.num_threads,
output=weighted_avg,
)
known_item_scores = self.mean_arr[user_idx] + weighted_avg
return known_item_scores
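
What compute_score_single computes for UserKNN, restated in NumPy (function and argument names are mine, not the library's): among the users who rated the item, keep the k with the largest similarity to the target user, then take a normalized weighted average of their mean-centered ratings; the caller adds the user's mean back, as score() does above:

import numpy as np

def user_knn_score(sim_row, rater_ids, rater_ratings, user_mean, k=20):
    # sim_row: similarities of the target user to all users;
    # rater_ids / rater_ratings: one column of the item-user matrix, i.e.
    # the users who rated the target item and their mean-centered ratings
    w = sim_row[rater_ids]
    keep = w != 0
    w, r = w[keep], rater_ratings[keep]
    top = np.argsort(w)[::-1][:k]        # k most similar raters
    num = np.dot(w[top], r[top])
    denom = np.abs(w[top]).sum() + 1e-8  # same EPS guard as the kernel
    return user_mean + num / denom
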
@@ -244,10 +277,14 @@ class ItemKNN(Recommender):
The number of nearest neighbors.
similarity: str, optional, default: 'cosine'
The similarity measurement. Supported types: ['cosine', 'adjusted', 'pearson']
The similarity measurement. Supported types: ['cosine', 'pearson']
mean_centered: bool, optional, default: False
Whether values of the user-item rating matrix will be centered by the mean
of their corresponding rows (mean rating of each user).
weighting: str, optional, default: None
The option for re-weighting the rating matrix. Supported types: [tf-idf', 'bm25'].
The option for re-weighting the rating matrix. Supported types: ['idf', 'bm25'].
If None, no weighting is applied.
amplify: float, optional, default: 1.0
@@ -266,14 +303,12 @@ class ItemKNN(Recommender):
* Aggarwal, C. C. (2016). Recommender systems (Vol. 1). Cham: Springer International Publishing.
"""
SIMILARITIES = ["cosine", "adjusted", "pearson"]
WEIGHTING_OPTIONS = ["tf-idf", "bm25"]
def __init__(
self,
name="ItemKNN",
k=20,
similarity="cosine",
mean_centered=False,
weighting=None,
amplify=1.0,
num_threads=0,
@@ -284,19 +319,20 @@
super().__init__(name=name, trainable=trainable, verbose=verbose)
self.k = k
self.similarity = similarity
self.mean_centered = mean_centered
self.weighting = weighting
self.amplify = amplify
self.seed = seed
self.rng = get_rng(seed)
if self.similarity not in self.SIMILARITIES:
if self.similarity not in SIMILARITIES:
raise ValueError(
"Invalid similarity choice, supported {}".format(self.SIMILARITIES)
"Invalid similarity choice, supported {}".format(SIMILARITIES)
)
if self.weighting is not None and self.weighting not in self.WEIGHTING_OPTIONS:
if self.weighting is not None and self.weighting not in WEIGHTING_OPTIONS:
raise ValueError(
"Invalid weighting choice, supported {}".format(self.WEIGHTING_OPTIONS)
"Invalid weighting choice, supported {}".format(WEIGHTING_OPTIONS)
)
if seed is not None:
@@ -325,32 +361,29 @@ class ItemKNN(Recommender):
self.ui_mat = self.train_set.matrix.copy()
self.mean_arr = np.zeros(self.ui_mat.shape[0])
explicit_feedback = self.train_set.min_rating != self.train_set.max_rating
if explicit_feedback:
if self.train_set.min_rating != self.train_set.max_rating: # explicit feedback
self.ui_mat, self.mean_arr = _mean_centered(self.ui_mat)
if self.similarity == "cosine":
if self.mean_centered:
weight_mat = self.ui_mat.copy()
else:
weight_mat = self.train_set.matrix.copy()
elif self.similarity == "adjusted":
weight_mat = self.ui_mat.copy() # mean-centered by rows
elif self.similarity == "pearson" and explicit_feedback:
weight_mat, _ = _mean_centered(
self.train_set.matrix.T.tocsr()
) # mean-centered by columns
if self.similarity == "pearson": # centered by columns
weight_mat, _ = _mean_centered(weight_mat.T.tocsr())
weight_mat = weight_mat.T.tocsr()
# rating matrix re-weighting
if self.weighting == "tf-idf":
weight_mat = _tfidf_weight(weight_mat)
# re-weighting
if self.weighting == "idf":
weight_mat.data *= np.sqrt(_idf_weight(self.train_set.matrix))
elif self.weighting == "bm25":
weight_mat = _bm25_weight(weight_mat)
weight_mat.data *= np.sqrt(_bm25_weight(self.train_set.matrix))
weight_mat = weight_mat.T.tocsr()
inplace_csr_row_normalize_l2(weight_mat)
self.sim_mat = compute_similarity(
weight_mat, k=self.k, num_threads=self.num_threads, verbose=self.verbose
).power(self.amplify)
)
self.sim_mat = _amplify(self.sim_mat, self.amplify)
return self
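
For ItemKNN, Pearson needs item-wise (column) centering; the patch gets it by transposing, reusing the row-centering helper, and transposing back. A compact sketch of that round trip on a toy matrix (helper name assumed):

import numpy as np
from scipy.sparse import csr_matrix

def column_mean_center(ui_mat):
    # center each column (item) by centering the rows of the transpose
    t = ui_mat.T.tocsr()
    for i in range(t.shape[0]):
        s, e = t.indptr[i], t.indptr[i + 1]
        t.data[s:e] -= t.data[s:e].mean()
    return t.T.tocsr()

R = csr_matrix(np.array([[5.0, 3.0], [4.0, 1.0], [3.0, 2.0]]))
print(column_mean_center(R).toarray())  # each column now sums to ~0
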
@@ -381,13 +414,27 @@
"Can't make score prediction for (item_id=%d)" % item_idx
)
user_profile = self.ui_mat[user_idx]
known_item_scores = self.mean_arr[user_idx] + (
user_profile.dot(self.sim_mat).A.ravel()
/ (np.abs(self.sim_mat).sum(axis=0).A.ravel() + EPS)
)
if item_idx is not None:
return known_item_scores[item_idx]
return known_item_scores
weighted_avg = compute_score_single(
False,
self.ui_mat[user_idx].A.ravel(),
self.sim_mat.indptr[item_idx],
self.sim_mat.indptr[item_idx + 1],
self.sim_mat.indices,
self.sim_mat.data,
k=self.k,
)
return self.mean_arr[user_idx] + weighted_avg
weighted_avg = np.zeros(self.train_set.num_items)
compute_score(
False,
self.ui_mat[user_idx].A.ravel(),
self.sim_mat.indptr,
self.sim_mat.indices,
self.sim_mat.data,
k=self.k,
num_threads=self.num_threads,
output=weighted_avg,
)
return self.mean_arr[user_idx] + weighted_avg
@@ -37,24 +37,19 @@ struct TopK
std::greater<std::pair<Value, Index>> heap_order;
};
/** A utility class to multiply rows of a sparse matrix
Implements the sparse matrix multiplication algorithm
described in the paper 'Sparse Matrix Multiplication Package (SMMP)'
http://www.i2m.univ-amu.fr/~bradji/multp_sparse.pdf
*/
template <typename Index, typename Value>
class SparseMatrixMultiplier
class SparseNeighbors
{
public:
explicit SparseMatrixMultiplier(Index count)
: sums(count, 0), nonzeros(count, -1), head(-2), length(0)
explicit SparseNeighbors(Index count)
: weights(count, 0), scores(count, 0), nonzeros(count, -1), head(-2), length(0)
{
}
/** Adds value to the element at index */
void add(Index index, Value value)
void set(Index index, Value weight, Value score)
{
sums[index] += value;
weights[index] = weight;
scores[index] = score;
if (nonzeros[index] == -1)
{
@@ -64,18 +59,18 @@ public:
}
}
/** Calls a function once per non-zero entry, also clears state for next run*/
template <typename Function>
void foreach (Function &f)
{ // NOLINT(*)
for (int i = 0; i < length; ++i)
{
Index index = head;
f(index, sums[index]);
f(scores[index], weights[index]);
// clear up memory and advance linked list
head = nonzeros[head];
sums[index] = 0;
weights[index] = 0;
scores[index] = 0;
nonzeros[index] = -1;
}
@@ -85,11 +80,13 @@ public:
Index nnz() const { return length; }
std::vector<Value> sums;
std::vector<Value> weights;
std::vector<Value> scores;
protected:
std::vector<Index> nonzeros;
Index head, length;
};
} // namespace cornac_knn
#endif // CORNAC_SIMILARITY_H_
\ No newline at end of file
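
The renamed SparseNeighbors keeps two dense arrays (weights, scores) plus an intrusive singly-linked list threaded through nonzeros, so visiting the touched entries and resetting them for the next row costs O(entries touched) rather than O(count). A Python rendering of the idea (a sketch of the data structure, not the shipped C++):

class SparseNeighbors:
    def __init__(self, count):
        self.weights = [0.0] * count
        self.scores = [0.0] * count
        self.nonzeros = [-1] * count  # next-pointer per slot; -1 means untouched
        self.head = -2
        self.length = 0

    def set(self, index, weight, score):
        self.weights[index] = weight
        self.scores[index] = score
        if self.nonzeros[index] == -1:  # first touch: push onto the list
            self.nonzeros[index] = self.head
            self.head = index
            self.length += 1

    def foreach(self, fn):
        # visit every touched slot once, clearing state for the next run
        for _ in range(self.length):
            index, self.head = self.head, self.nonzeros[self.head]
            fn(self.scores[index], self.weights[index])
            self.weights[index] = self.scores[index] = 0.0
            self.nonzeros[index] = -1
        self.length = 0
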
@@ -19,10 +19,11 @@ import cython
from cython cimport floating, integral
from cython.operator import dereference
from cython.parallel import parallel, prange
from libc.math cimport sqrt
from libc.math cimport sqrt, fabs
from libcpp cimport bool
from libcpp.vector cimport vector
from libcpp.utility cimport pair
from libc.stdlib cimport abort, malloc, free
import threading
import numpy as np
@@ -38,11 +39,12 @@ cdef extern from "similarity.h" namespace "cornac_knn" nogil:
TopK(size_t K)
vector[pair[Value, Index]] results
cdef cppclass SparseMatrixMultiplier[Index, Value]:
SparseMatrixMultiplier(Index n_rows)
void add(Index index, Value value)
cdef cppclass SparseNeighbors[Index, Value]:
SparseNeighbors(Index max_neighbors)
void set(Index index, Value weight, Value score)
void foreach[Function](Function & f)
vector[Value] sums
vector[Value] weights
vector[Value] scores
@cython.boundscheck(False)
@@ -54,53 +56,146 @@ def compute_similarity(data_mat, unsigned int k=20, unsigned int num_threads=0,
cdef int n_rows = row_mat.shape[0]
cdef int r, c, i, j
cdef double w
cdef double w, denom
cdef int[:] row_indptr = row_mat.indptr, row_indices = row_mat.indices
cdef double[:] row_data = row_mat.data
cdef int[:] col_indptr = col_mat.indptr, col_indices = col_mat.indices
cdef double[:] col_data = col_mat.data
cdef SparseMatrixMultiplier[int, double] * neighbours
cdef TopK[int, double] * topk
cdef pair[double, int] result
# holds triples of output similarity matrix
cdef double[:] values = np.zeros(n_rows * k)
cdef long[:] rows = np.zeros(n_rows * k, dtype=int)
cdef long[:] cols = np.zeros(n_rows * k, dtype=int)
cdef double[:, :] sim_mat = np.zeros((n_rows, n_rows))
cdef double * denom1
cdef double * denom2
progress = tqdm(total=n_rows, disable=not verbose)
with nogil, parallel(num_threads=num_threads):
# allocate memory per thread
neighbours = new SparseMatrixMultiplier[int, double](n_rows)
topk = new TopK[int, double](k)
try:
for r in prange(n_rows, schedule='guided'):
for i in range(row_indptr[r], row_indptr[r + 1]):
c = row_indices[i]
w = row_data[i]
for j in range(col_indptr[c], col_indptr[c + 1]):
neighbours.add(col_indices[j], col_data[j] * w)
topk.results.clear()
neighbours.foreach(dereference(topk))
i = k * r
for result in topk.results:
rows[i] = r
cols[i] = result.second
values[i] = result.first
i = i + 1
with gil:
progress.update(1)
finally:
del neighbours
del topk
for r in prange(n_rows, schedule='guided'):
denom1 = <double *> malloc(sizeof(double) * n_rows)
denom2 = <double *> malloc(sizeof(double) * n_rows)
if denom1 is NULL or denom2 is NULL:
abort()
for i in range(n_rows):
denom1[i] = 0
denom2[i] = 0
for i in range(row_indptr[r], row_indptr[r + 1]):
c, w = row_indices[i], row_data[i]
for j in range(col_indptr[c], col_indptr[c + 1]): # neighbors
sim_mat[r, col_indices[j]] += col_data[j] * w
if w != 0 and col_data[j] != 0:
denom1[col_indices[j]] += w * w
denom2[col_indices[j]] += col_data[j] * col_data[j]
for i in range(n_rows):
if sim_mat[r, i] != 0:
denom = sqrt(denom1[i]) * sqrt(denom2[i])
sim_mat[r, i] /= denom
free(denom1)
free(denom2)
with gil:
progress.update(1)
progress.close()
return csr_matrix((values, (rows, cols)), shape=(n_rows, n_rows))
\ No newline at end of file
sparse_sim_mat = csr_matrix(sim_mat)
del sim_mat
return sparse_sim_mat
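
Two things changed in this kernel: similarities are accumulated into a dense row per thread instead of going through the old SMMP-style multiplier, and each pair is normalized by norms computed only over co-rated columns (what denom1/denom2 track), not whole-row norms. Top-k pruning is gone here; the full matrix is kept and k is applied later at scoring time. A single-threaded NumPy equivalent, with names assumed:

import numpy as np
from scipy.sparse import csr_matrix

def pairwise_similarity(row_mat):
    A = csr_matrix(row_mat, dtype=np.float64)
    col_mat = A.T.tocsr()  # for each column of A: the rows containing it
    n = A.shape[0]
    sim = np.zeros((n, n))
    for r in range(n):
        num, d1, d2 = np.zeros(n), np.zeros(n), np.zeros(n)
        for i in range(A.indptr[r], A.indptr[r + 1]):
            c, w = A.indices[i], A.data[i]
            for j in range(col_mat.indptr[c], col_mat.indptr[c + 1]):
                o, v = col_mat.indices[j], col_mat.data[j]
                num[o] += w * v
                d1[o] += w * w  # the Cython code also guards stored zeros here
                d2[o] += v * v
        nz = num != 0
        sim[r, nz] = num[nz] / (np.sqrt(d1[nz]) * np.sqrt(d2[nz]))
    return csr_matrix(sim)
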
@cython.boundscheck(False)
def compute_score_single(
bool user_mode,
floating[:] sim_arr,
int ptr1,
int ptr2,
int[:] indices,
floating[:] data,
int k
):
cdef int max_neighbors = sim_arr.shape[0]
cdef int nn, j
cdef double w, s, num, denom, output
cdef SparseNeighbors[int, double] * neighbours = new SparseNeighbors[int, double](max_neighbors)
cdef TopK[double, double] * topk = new TopK[double, double](k)
cdef pair[double, double] result
for j in range(ptr1, ptr2):
nn, s = indices[j], data[j]
if sim_arr[nn] != 0:
if user_mode:
neighbours.set(nn, sim_arr[nn], s)
else:
neighbours.set(nn, s, sim_arr[nn])
topk.results.clear()
neighbours.foreach(dereference(topk))
num = 0
denom = 0
for result in topk.results:
w = result.first
s = result.second
num = num + w * s
denom = denom + fabs(w)
output = num / (denom + 1e-8)
del topk
del neighbours
return output
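
compute_score and compute_score_single share one reduction; only the roles of the two values flip with user_mode. For UserKNN the weight is a user-user similarity and the score a neighbor's centered rating; for ItemKNN the weight is an item-item similarity and the score the user's own rating. In pure Python (names assumed):

def weighted_average(pairs, k, eps=1e-8):
    # pairs: (weight, score) per candidate neighbor, as fed to SparseNeighbors
    top = sorted(pairs, key=lambda p: p[0], reverse=True)[:k]  # TopK by weight
    num = sum(w * s for w, s in top)
    denom = sum(abs(w) for w, s in top)
    return num / (denom + eps)

# UserKNN: weighted_average([(sim_to_rater, rater_centered_rating), ...], k)
# ItemKNN: weighted_average([(item_item_sim, users_own_rating), ...], k)
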
@cython.boundscheck(False)
def compute_score(
bool user_mode,
floating[:] sim_arr,
int[:] indptr,
int[:] indices,
floating[:] data,
unsigned int k,
unsigned int num_threads,
floating[:] output
):
cdef int max_neighbors = sim_arr.shape[0]
cdef int n_items = output.shape[0]
cdef int nn, i, j
cdef double w, s, num, denom
cdef SparseNeighbors[int, double] * neighbours
cdef TopK[double, double] * topk
cdef pair[double, double] result
with nogil, parallel(num_threads=num_threads):
for i in prange(n_items, schedule='guided'):
# allocate memory per thread
neighbours = new SparseNeighbors[int, double](max_neighbors)
topk = new TopK[double, double](k)
for j in range(indptr[i], indptr[i + 1]):
nn, s = indices[j], data[j]
if sim_arr[nn] != 0:
if user_mode:
neighbours.set(nn, sim_arr[nn], s)
else:
neighbours.set(nn, s, sim_arr[nn])
topk.results.clear()
neighbours.foreach(dereference(topk))
num = 0
denom = 0
for result in topk.results:
w = result.first
s = result.second
num = num + w * s
denom = denom + fabs(w)
output[i] = num / (denom + 1e-8)
del neighbours
del topk
\ No newline at end of file
@@ -19,58 +19,53 @@ from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
K = 50 # number of nearest neighbors
# Load ML-100K dataset
feedback = movielens.load_feedback(variant="100K")
# Define an evaluation method to split feedback into train and test sets
ratio_split = RatioSplit(
data=feedback,
test_size=0.2,
rating_threshold=4.0,
exclude_unknowns=True,
verbose=True,
seed=123,
data=feedback, test_size=0.2, exclude_unknowns=True, verbose=True, seed=123
)
# Comparing a few variants of KNN methods
user_knn_cosine = cornac.models.UserKNN(
k=20, similarity="cosine", amplify=1.0, name="UserKNN-Cosine"
)
# UserKNN methods
user_knn_cosine = cornac.models.UserKNN(k=K, similarity="cosine", name="UserKNN-Cosine")
user_knn_pearson = cornac.models.UserKNN(
k=20, similarity="pearson", amplify=1.0, name="UserKNN-Pearson"
)
user_knn_tfidf = cornac.models.UserKNN(
k=20, similarity="cosine", weighting="tf-idf", amplify=1.0, name="UserKNN-TFIDF"
k=K, similarity="pearson", name="UserKNN-Pearson"
)
user_knn_bm25 = cornac.models.UserKNN(
k=20, similarity="cosine", weighting="bm25", amplify=1.0, name="UserKNN-BM25"
user_knn_amp = cornac.models.UserKNN(
k=K, similarity="cosine", amplify=2.0, name="UserKNN-Amplified"
)
item_knn_cosine = cornac.models.ItemKNN(
k=20, similarity="cosine", amplify=1.0, name="ItemKNN-Cosine"
user_knn_idf = cornac.models.UserKNN(
k=K, similarity="cosine", weighting="idf", name="UserKNN-IDF"
)
item_knn_adjusted_cosine = cornac.models.ItemKNN(
k=20, similarity="adjusted", amplify=1.0, name="ItemKNN-AdjustedCosine"
user_knn_bm25 = cornac.models.UserKNN(
k=K, similarity="cosine", weighting="bm25", name="UserKNN-BM25"
)
# ItemKNN methods
item_knn_cosine = cornac.models.ItemKNN(k=K, similarity="cosine", name="ItemKNN-Cosine")
item_knn_pearson = cornac.models.ItemKNN(
k=20, similarity="pearson", amplify=1.0, name="ItemKNN-Pearson"
k=K, similarity="pearson", name="ItemKNN-Pearson"
)
item_knn_adjusted = cornac.models.ItemKNN(
k=K, similarity="cosine", mean_centered=True, name="ItemKNN-AdjustedCosine"
)
# Evaluation metrics
rmse = cornac.metrics.RMSE()
rec_20 = cornac.metrics.Recall(k=20)
# Put everything together into an experiment and run it
# Put everything together into an experiment
cornac.Experiment(
eval_method=ratio_split,
models=[
user_knn_cosine,
user_knn_pearson,
user_knn_tfidf,
user_knn_amp,
user_knn_idf,
user_knn_bm25,
item_knn_cosine,
item_knn_adjusted_cosine,
item_knn_pearson,
item_knn_adjusted,
],
metrics=[rmse, rec_20],
metrics=[cornac.metrics.RMSE()],
user_based=True,
).run()