Commit 98ccb808 authored by Jaime Hieu Do, committed by GitHub
Add PyTorch backend for MF (#546)


* add PyTorch backend for MF

* update example

---------

Co-authored-by: tqtg <tuantq.vnu@gmail.com>
parent b44e19c9
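In short, cornac.models.MF can now be trained with either the original Cython/OpenMP backend (backend="cpu") or the new PyTorch backend (backend="pytorch"), which additionally exposes optimizer, batch_size, and dropout options. A minimal usage sketch, mirroring the updated example at the bottom of this diff:

import cornac

# PyTorch backend; _fit_pt trains on CUDA when available, otherwise on CPU
mf = cornac.models.MF(k=10, backend="pytorch", optimizer="adam", batch_size=256, dropout=0.0, max_iter=25)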
# File: cornac/models/mf/__init__.py
from .recom_mf import MF
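# File: cornac/models/mf/backend_cpu.pyx (new): Cython/OpenMP SGD training loop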
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# cython: language_level=3
import multiprocessing
cimport cython
from cython.parallel import prange
from cython cimport floating, integral
from libcpp cimport bool
from libc.math cimport abs
import numpy as np
cimport numpy as np
from tqdm.auto import trange
@cython.boundscheck(False)
@cython.wraparound(False)
def fit_sgd(integral[:] rid, integral[:] cid, floating[:] val,
floating[:, :] U, floating[:, :] V,
floating[:] Bu, floating[:] Bi,
long num_users, long num_items,
floating lr, floating reg, floating mu,
int max_iter, int num_threads,
bool use_bias, bool early_stop, bool verbose):
"""Fit the model parameters (U, V, Bu, Bi) with SGD"""
cdef:
long num_ratings = val.shape[0]
int num_factors = U.shape[1]
floating loss = 0
floating last_loss = 0
floating r, r_pred, error, u_f, i_f, delta_loss
integral u, i, f, j
floating * user
floating * item
progress = trange(max_iter, disable=not verbose)
for epoch in progress:
last_loss = loss
loss = 0
for j in prange(num_ratings, nogil=True, schedule='static', num_threads=num_threads):
u, i, r = rid[j], cid[j], val[j]
user, item = &U[u, 0], &V[i, 0]
# predict rating
r_pred = mu + Bu[u] + Bi[i]
for f in range(num_factors):
r_pred = r_pred + user[f] * item[f]
error = r - r_pred
loss += error * error
# update factors
for f in range(num_factors):
u_f, i_f = user[f], item[f]
user[f] += lr * (error * i_f - reg * u_f)
item[f] += lr * (error * u_f - reg * i_f)
# update biases
if use_bias:
Bu[u] += lr * (error - reg * Bu[u])
Bi[i] += lr * (error - reg * Bi[i])
loss = 0.5 * loss
progress.set_postfix({"loss": "%.2f" % loss})
delta_loss = loss - last_loss
if early_stop and abs(delta_loss) < 1e-5:
if verbose:
print('Early stopping, delta_loss = %.4f' % delta_loss)
break
progress.close()
if verbose:
print('Optimization finished!')
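# File: cornac/models/mf/backend_pt.py (new): PyTorch training backend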
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import torch.nn as nn
from tqdm.auto import trange
OPTIMIZER_DICT = {
"sgd": torch.optim.SGD,
"adam": torch.optim.Adam,
"rmsprop": torch.optim.RMSprop,
"adagrad": torch.optim.Adagrad,
}
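# Torch module for MF: user/item embedding tables, optional bias embeddings, and the global mean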
class MF(nn.Module):
def __init__(
self,
u_factors,
i_factors,
u_biases,
i_biases,
use_bias,
global_mean,
dropout,
):
super(MF, self).__init__()
self.use_bias = use_bias
self.global_mean = global_mean
self.dropout = nn.Dropout(p=dropout)
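# embedding tables are created with the shapes of, and initialized from, the given numpy factor matrices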
self.u_factors = nn.Embedding(*u_factors.shape)
self.i_factors = nn.Embedding(*i_factors.shape)
self.u_factors.weight.data = torch.from_numpy(u_factors)
self.i_factors.weight.data = torch.from_numpy(i_factors)
if use_bias:
self.u_biases = nn.Embedding(*u_biases.shape)
self.i_biases = nn.Embedding(*i_biases.shape)
self.u_biases.weight.data = torch.from_numpy(u_biases)
self.i_biases.weight.data = torch.from_numpy(i_biases)
def forward(self, uids, iids):
ues = self.u_factors(uids)
ies = self.i_factors(iids)
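# predicted rating: dot product of the (dropout-regularized) user and item embeddings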
preds = (self.dropout(ues) * self.dropout(ies)).sum(dim=1, keepdim=True)
if self.use_bias:
preds += self.u_biases(uids) + self.i_biases(iids) + self.global_mean
return preds.squeeze()
def learn(
model,
train_set,
n_epochs,
batch_size=256,
learning_rate=0.01,
reg=0.0,
verbose=True,
optimizer="sgd",
device=torch.device("cpu"),
):
model = model.to(device)
criteria = nn.MSELoss(reduction="sum")
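# look up the torch.optim class by name; weight_decay implements the L2 regularization term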
optimizer = OPTIMIZER_DICT[optimizer](
params=model.parameters(), lr=learning_rate, weight_decay=reg
)
progress_bar = trange(1, n_epochs + 1, disable=not verbose)
for _ in progress_bar:
sum_loss = 0.0
count = 0
for batch_id, (u_batch, i_batch, r_batch) in enumerate(
train_set.uir_iter(batch_size, shuffle=True)
):
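# move the (user, item, rating) mini-batch onto the training device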
u_batch = torch.from_numpy(u_batch).to(device)
i_batch = torch.from_numpy(i_batch).to(device)
r_batch = torch.tensor(r_batch, dtype=torch.float).to(device)
preds = model(u_batch, i_batch)
loss = criteria(preds, r_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
sum_loss += loss.data.item()
count += len(u_batch)
if batch_id % 10 == 0:
progress_bar.set_postfix(loss=(sum_loss / count))
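# File: cornac/models/mf/recom_mf.py (training loops moved out to the backend modules above)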
@@ -13,18 +13,10 @@
# limitations under the License.
# ============================================================================
import multiprocessing
import numpy as np
from ..recommender import Recommender
@@ -35,7 +27,6 @@ from ...utils import get_rng
from ...utils.init_utils import normal, zeros
class MF(Recommender, ANNMixin):
"""Matrix Factorization.
@@ -44,14 +35,26 @@
k: int, optional, default: 10
The dimension of the latent factors.
backend: str, optional, default: 'cpu'
Backend used for model training: 'cpu' or 'pytorch'.
optimizer: str, optional, default: 'sgd'
Optimizer used by the PyTorch backend: 'adagrad', 'adam', 'rmsprop', or 'sgd' (ignored by the CPU backend).
max_iter: int, optional, default: 20
Maximum number of iterations or the number of epochs for training.
learning_rate: float, optional, default: 0.01
The learning rate.
batch_size: int, optional, default: 256
Batch size (ignored by the CPU backend).
lambda_reg: float, optional, default: 0.001
The lambda value used for regularization.
dropout: float, optional, default: 0.0
Dropout rate applied to the user/item embeddings (ignored by the CPU backend).
use_bias: boolean, optional, default: True
When True, user, item, and global biases are used.
@@ -61,7 +64,8 @@
num_threads: int, optional, default: 0
Number of parallel threads for training. If num_threads=0, all CPU cores will be utilized.
If seed is not None, num_threads is set to 1 to remove randomness from parallelization
(only effective with the CPU backend).
trainable: boolean, optional, default: True
When False, the model will not be re-trained, and input of pre-trained parameters are required.
@@ -84,25 +88,33 @@
"""
def __init__(
self,
name="MF",
k=10,
backend="cpu",
optimizer="sgd",
max_iter=20,
learning_rate=0.01,
batch_size=256,
lambda_reg=0.02,
dropout=0.0,
use_bias=True,
early_stop=False,
num_threads=0,
trainable=True,
verbose=False,
init_params=None,
seed=None,
):
super().__init__(name=name, trainable=trainable, verbose=verbose)
self.k = k
self.backend = backend
self.optimizer = optimizer
self.max_iter = max_iter
self.learning_rate = learning_rate
self.batch_size = batch_size
self.lambda_reg = lambda_reg
self.dropout = dropout
self.use_bias = use_bias
self.early_stop = early_stop
self.seed = seed
@@ -116,21 +128,29 @@
# Init params if provided
self.init_params = {} if init_params is None else init_params
self.u_factors = self.init_params.get("U", None)
self.i_factors = self.init_params.get("V", None)
self.u_biases = self.init_params.get("Bu", None)
self.i_biases = self.init_params.get("Bi", None)
def _init(self):
rng = get_rng(self.seed)
if self.u_factors is None:
self.u_factors = normal(
[self.num_users, self.k], std=0.01, random_state=rng
)
if self.i_factors is None:
self.i_factors = normal(
[self.num_items, self.k], std=0.01, random_state=rng
)
self.u_biases = (
zeros(self.num_users) if self.u_biases is None else self.u_biases
)
self.i_biases = (
zeros(self.num_items) if self.i_biases is None else self.i_biases
)
self.global_mean = self.global_mean if self.use_bias else 0.0
def fit(self, train_set, val_set=None):
@@ -153,84 +173,83 @@
self._init()
if self.trainable:
if self.backend == "cpu":
self._fit_cpu(train_set, val_set)
elif self.backend == "pytorch":
self._fit_pt(train_set, val_set)
else:
raise ValueError(f"{self.backend} is not supported")
return self
#################
## CPU backend ##
#################
def _fit_cpu(self, train_set, val_set):
from cornac.models.mf import backend_cpu
(rid, cid, val) = train_set.uir_tuple
backend_cpu.fit_sgd(
rid,
cid,
val.astype(np.float32),
self.u_factors,
self.i_factors,
self.u_biases,
self.i_biases,
self.num_users,
self.num_items,
self.learning_rate,
self.lambda_reg,
self.global_mean,
self.max_iter,
self.num_threads,
self.use_bias,
self.early_stop,
self.verbose,
)
#####################
## PyTorch backend ##
#####################
def _fit_pt(self, train_set, val_set):
import torch
from .backend_pt import MF, learn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
if self.seed is not None:
torch.manual_seed(self.seed)
np.random.seed(self.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(self.seed)
model = MF(
self.u_factors,
self.i_factors,
self.u_biases.reshape(-1, 1),
self.i_biases.reshape(-1, 1),
self.use_bias,
self.global_mean,
self.dropout,
)
learn(
model=model,
train_set=train_set,
n_epochs=self.max_iter,
batch_size=self.batch_size,
learning_rate=self.learning_rate,
reg=self.lambda_reg,
optimizer=self.optimizer,
device=device,
)
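# copy the learned parameters back to numpy so that score() and ANN search work the same for both backends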
self.u_factors = model.u_factors.weight.detach().cpu().numpy()
self.i_factors = model.i_factors.weight.detach().cpu().numpy()
if self.use_bias:
self.u_biases = model.u_biases.weight.detach().cpu().squeeze().numpy()
self.i_biases = model.i_biases.weight.detach().cpu().squeeze().numpy()
def score(self, user_idx, item_idx=None):
"""Predict the scores/ratings of a user for an item.
@@ -264,11 +283,19 @@
if self.knows_item(item_idx):
item_score += self.i_biases[item_idx]
if self.knows_user(user_idx) and self.knows_item(item_idx):
item_score += np.dot(
self.u_factors[user_idx], self.i_factors[item_idx]
)
else:
if self.knows_user(user_idx) and self.knows_item(item_idx):
item_score = np.dot(
self.u_factors[user_idx], self.i_factors[item_idx]
)
else:
raise ScoreException(
"Can't make score prediction for (user_id=%d, item_id=%d)"
% (user_idx, item_idx)
)
return item_score
def get_vector_measure(self):
@@ -287,7 +314,7 @@
Returns
-------
out: numpy.array
Matrix of user vectors for all users available in the model.
"""
user_vectors = self.u_factors
if self.use_bias:
@@ -295,28 +322,28 @@
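# the augmented user vector [u, b_u, 1] dotted with an item vector [v, 1, b_i] gives u.v + b_u + b_i, matching score()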
(
user_vectors,
self.u_biases.reshape((-1, 1)),
np.ones([user_vectors.shape[0], 1]), # augmented for item bias
),
axis=1,
)
return user_vectors
def get_item_vectors(self):
"""Getting a matrix of item vectors used for building the index for ANN search.
Returns
-------
out: numpy.array
Matrix of item vectors for all items available in the model.
"""
item_vectors = self.i_factors
if self.use_bias:
item_vectors = np.concatenate(
(
item_vectors,
np.ones([item_vectors.shape[0], 1]), # augmented for user bias
self.i_biases.reshape((-1, 1)),
),
axis=1,
)
return item_vectors
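# File: MF example script, updated to compare the CPU and PyTorch backends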
@@ -18,7 +18,6 @@ import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
# Load MovieLens 1M ratings
ml_1m = movielens.load_feedback(variant="1M")
@@ -31,22 +30,42 @@ ratio_split = RatioSplit(
global_avg = cornac.models.GlobalAvg()
mf = cornac.models.MF(
k=10,
backend="cpu",
max_iter=25,
learning_rate=0.01,
lambda_reg=0.02,
use_bias=True,
early_stop=True,
verbose=True,
name="MF-cpu",
)
tmf = cornac.models.MF(
k=10,
backend="pytorch",
optimizer="sgd",
max_iter=25,
batch_size=256,
learning_rate=0.01,
lambda_reg=1e-2,
trainable=True,
verbose=True,
name="MF-pytorch",
)
# Instantiate MAE and RMSE for evaluation
mae = cornac.metrics.MAE()
rmse = cornac.metrics.RMSE()
ndcg = cornac.metrics.NDCG(k=10)
recall = cornac.metrics.Recall(k=10)
# Put everything together into an experiment and run it
cornac.Experiment(
eval_method=ratio_split,
models=[
global_avg,
mf,
tmf,
],
metrics=[mae, rmse, ndcg, recall],
user_based=True,
).run()
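# File: setup.py (the Cython extension is renamed from recom_mf to backend_cpu)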
@@ -158,8 +158,8 @@ extensions = [
language="c++",
),
Extension(
name="cornac.models.mf.recom_mf",
sources=["cornac/models/mf/recom_mf" + ext],
name="cornac.models.mf.backend_cpu",
sources=["cornac/models/mf/backend_cpu" + ext],
include_dirs=[np.get_include()],
language="c++",
extra_compile_args=compile_args,
......