Commit ef374a4e authored by Lê Trung Hoàng, committed by GitHub

Add Correlation-Sensitive Next-Basket Recommendation (Beacon) Model (#584)

* Add beacon model

* Add example

* Update docs

* Fixed data_iter

* refactor code
parent b845e888
README.md

...@@ -153,9 +153,10 @@ The recommender models supported by Cornac are listed below. Why don't you join
| | [Hybrid neural recommendation with joint deep representation learning of ratings and reviews (HRDR)](cornac/models/hrdr), [paper](https://www.sciencedirect.com/science/article/abs/pii/S0925231219313207) | [requirements.txt](cornac/models/hrdr/requirements.txt) | [hrdr_example.py](examples/hrdr_example.py)
| | [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](cornac/models/lightgcn), [paper](https://arxiv.org/pdf/2002.02126.pdf) | [requirements.txt](cornac/models/lightgcn/requirements.txt) | [lightgcn_example.py](examples/lightgcn_example.py)
| | [New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (RecVAE)](cornac/models/recvae), [paper](https://doi.org/10.1145/3336191.3371831) | [requirements.txt](cornac/models/recvae/requirements.txt) | [recvae_example.py](examples/recvae_example.py)
| | [Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF)](cornac/models/upcf), [paper](https://dl.acm.org/doi/abs/10.1145/3340631.3394850) | [requirements.txt](cornac/models/upcf/requirements.txt) | [upcf_tafeng.py](examples/upcf_tafeng.py)
| | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | N/A | [tifuknn_tafeng.py](examples/tifuknn_tafeng.py)
| 2019 | [Correlation-Sensitive Next-Basket Recommendation (Beacon)](cornac/models/beacon), [paper](https://www.ijcai.org/proceedings/2019/0389.pdf) | [requirements.txt](cornac/models/beacon/requirements.txt) | [beacon_tafeng.py](examples/beacon_tafeng.py)
| | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py)
| | [Neural Graph Collaborative Filtering (NGCF)](cornac/models/ngcf), [paper](https://arxiv.org/pdf/1905.08108.pdf) | [requirements.txt](cornac/models/ngcf/requirements.txt) | [ngcf_example.py](examples/ngcf_example.py)
| 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_example.py](examples/c2pf_example.py)
| | [Graph Convolutional Matrix Completion (GCMC)](cornac/models/gcmc), [paper](https://www.kdd.org/kdd2018/files/deep-learning-day/DLDay18_paper_32.pdf) | [requirements.txt](cornac/models/gcmc/requirements.txt) | [gcmc_example.py](examples/gcmc_example.py)
...

cornac/models/__init__.py

...@@ -23,6 +23,7 @@ from .ann import FaissANN
from .ann import HNSWLibANN
from .ann import ScaNNANN
from .baseline_only import BaselineOnly
from .beacon import Beacon
from .bivaecf import BiVAECF
from .bpr import BPR
from .bpr import WBPR
...

cornac/models/beacon/__init__.py (new file)
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .recom_beacon import Beacon

cornac/models/beacon/beacon_tf.py (new file)
import numpy as np
import warnings
# silence noisy deprecation warnings raised by the TF1 compat API
warnings.filterwarnings("ignore", category=UserWarning)
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
tf.disable_v2_behavior()
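

# TensorFlow 1.x (compat-mode) implementation of the Beacon network.
# The Cornac wrapper in recom_beacon.py builds this graph and drives training.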
def create_rnn_cell(cell_type, state_size, default_initializer, reuse=None):
if cell_type == "GRU":
return tf.nn.rnn_cell.GRUCell(state_size, activation=tf.nn.tanh, reuse=reuse)
elif cell_type == "LSTM":
return tf.nn.rnn_cell.LSTMCell(
state_size,
initializer=default_initializer,
activation=tf.nn.tanh,
reuse=reuse,
)
else:
return tf.nn.rnn_cell.BasicRNNCell(
state_size, activation=tf.nn.tanh, reuse=reuse
)
def create_rnn_encoder(
x,
rnn_units,
dropout_rate,
seq_length,
rnn_cell_type,
param_initializer,
seed,
reuse=None,
):
with tf.variable_scope("RNN_Encoder", reuse=reuse):
rnn_cell = create_rnn_cell(rnn_cell_type, rnn_units, param_initializer)
rnn_cell = tf.nn.rnn_cell.DropoutWrapper(
rnn_cell, input_keep_prob=1 - dropout_rate, seed=seed
)
init_state = rnn_cell.zero_state(tf.shape(x)[0], tf.float32)
# RNN Encoder: Iteratively compute output of recurrent network
rnn_outputs, _ = tf.nn.dynamic_rnn(
rnn_cell,
x,
initial_state=init_state,
sequence_length=seq_length,
dtype=tf.float32,
)
return rnn_outputs
def create_basket_encoder(
x,
dense_units,
param_initializer,
activation_func=None,
name="Basket_Encoder",
reuse=None,
):
with tf.variable_scope(name, reuse=reuse):
return tf.layers.dense(
x,
dense_units,
kernel_initializer=param_initializer,
bias_initializer=tf.zeros_initializer,
activation=activation_func,
)
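

# Select the output at the last valid timestep of each sequence: flatten
# [batch, max_length, units] to [batch * max_length, units], then gather
# row (b * max_length + actual_length[b] - 1) for each sample b.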
def get_last_right_output(full_output, max_length, actual_length, rnn_units):
batch_size = tf.shape(full_output)[0]
# Start indices for each sample
index = tf.range(0, batch_size) * max_length + (actual_length - 1)
# Indexing
return tf.gather(tf.reshape(full_output, [-1, rnn_units]), index)
class BeaconModel:
def __init__(
self,
sess,
emb_dim,
rnn_units,
alpha,
max_seq_length,
n_items,
item_probs,
adj_matrix,
rnn_cell_type,
rnn_dropout_rate,
seed,
lr,
):
self.scope = "GRN"
self.session = sess
self.seed = seed
self.lr = tf.constant(lr)
self.emb_dim = emb_dim
self.rnn_units = rnn_units
self.max_seq_length = max_seq_length
self.n_items = n_items
self.item_probs = item_probs
self.alpha = alpha
with tf.variable_scope(self.scope):
            # Fixed n-hop item correlation (adjacency) matrix
self.A = tf.constant(
adj_matrix.todense(), name="Adj_Matrix", dtype=tf.float32
)
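            # Learnable correlation-encoder parameters: I_B holds per-item
            # weights (used as a diagonal matrix) and C_Basket is a threshold
            # subtracted before the ReLU when propagating through A.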
uniform_initializer = (
np.ones(shape=(self.n_items), dtype=np.float32) / self.n_items
)
self.I_B = tf.get_variable(
dtype=tf.float32,
initializer=tf.constant(uniform_initializer, dtype=tf.float32),
name="I_B",
)
self.I_B_Diag = tf.nn.relu(tf.diag(self.I_B, name="I_B_Diag"))
self.C_Basket = tf.get_variable(
dtype=tf.float32, initializer=tf.constant(adj_matrix.mean()), name="C_B"
)
self.y = tf.placeholder(
dtype=tf.float32,
shape=(None, self.n_items),
name="Target_basket",
)
# Basket Sequence encoder
with tf.name_scope("Basket_Sequence_Encoder"):
self.bseq = tf.sparse.placeholder(
dtype=tf.float32,
name="bseq_input",
)
self.bseq_length = tf.placeholder(
dtype=tf.int32, shape=(None,), name="bseq_length"
)
self.bseq_encoder = tf.sparse.reshape(
self.bseq, shape=[-1, self.n_items], name="bseq_2d"
)
self.bseq_encoder = self.encode_basket_graph(
self.bseq_encoder, self.C_Basket, True
)
self.bseq_encoder = tf.reshape(
self.bseq_encoder,
shape=[-1, self.max_seq_length, self.n_items],
name="bsxMxN",
)
self.bseq_encoder = create_basket_encoder(
self.bseq_encoder,
emb_dim,
param_initializer=tf.initializers.he_uniform(),
activation_func=tf.nn.relu,
)
# batch_size x max_seq_length x H
rnn_encoder = create_rnn_encoder(
self.bseq_encoder,
self.rnn_units,
rnn_dropout_rate,
self.bseq_length,
rnn_cell_type,
param_initializer=tf.initializers.glorot_uniform(),
seed=self.seed,
)
            # Retrieve the RNN output at each sequence's last valid timestep.  # batch_size x H
h_T = get_last_right_output(
rnn_encoder, self.max_seq_length, self.bseq_length, self.rnn_units
)
# Next basket estimation
with tf.name_scope("Next_Basket"):
W_H = tf.get_variable(
dtype=tf.float32,
initializer=tf.initializers.glorot_uniform(),
shape=(self.rnn_units, self.n_items),
name="W_H",
)
next_item_probs = tf.nn.sigmoid(tf.matmul(h_T, W_H))
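            # Blend the sequential prediction with the correlations of the
            # predicted items; alpha trades off the two associations.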
logits = (
1.0 - self.alpha
) * next_item_probs + self.alpha * self.encode_basket_graph(
next_item_probs, tf.constant(0.0)
)
with tf.name_scope("Loss"):
self.loss = self.compute_loss(logits, self.y)
self.predictions = tf.nn.sigmoid(logits)
        # RMSProp optimizer
train_op = tf.train.RMSPropOptimizer(learning_rate=self.lr)
# Op to calculate every variable gradient
self.grads = train_op.compute_gradients(self.loss, tf.trainable_variables())
self.update_grads = train_op.apply_gradients(self.grads)
def train_batch(self, s, s_length, y):
bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True)
[_, loss] = self.session.run(
[self.update_grads, self.loss],
feed_dict={
self.bseq: (bseq_indices, bseq_values, bseq_shape),
self.bseq_length: s_length,
self.y: y,
},
)
return loss
def validate_batch(self, s, s_length, y):
bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True)
loss = self.session.run(
self.loss,
feed_dict={
self.bseq: (bseq_indices, bseq_values, bseq_shape),
self.bseq_length: s_length,
self.y: y,
},
)
return loss
def predict(self, s, s_length):
bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True)
predictions = self.session.run(
self.predictions,
feed_dict={
self.bseq: (bseq_indices, bseq_values, bseq_shape),
self.bseq_length: s_length,
},
)
return predictions.squeeze()
def encode_basket_graph(self, binput, beta, is_sparse=False):
with tf.name_scope("Graph_Encoder"):
if is_sparse:
encoder = tf.sparse_tensor_dense_matmul(
binput, self.I_B_Diag, name="XxI_B"
)
encoder += self.relu_with_threshold(
tf.sparse_tensor_dense_matmul(binput, self.A, name="XxA"), beta
)
else:
encoder = tf.matmul(binput, self.I_B_Diag, name="XxI_B")
encoder += self.relu_with_threshold(
tf.matmul(binput, self.A, name="XxA"), beta
)
return encoder
def get_sparse_tensor_info(self, x, is_bseq=False):
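        # Build (indices, values, dense_shape) for a tf.SparseTensor with one
        # 1.0 entry per observed item. is_bseq=True means x is a batch of
        # basket sequences (3-D indices); note the returned dense_shape is
        # always the 3-D sequence shape, and only the is_bseq=True path is
        # exercised by this model.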
indices = []
if is_bseq:
for sid, bseq in enumerate(x):
for t, basket in enumerate(bseq):
for item_id in basket:
indices.append([sid, t, item_id])
else:
for bid, basket in enumerate(x):
for item_id in basket:
indices.append([bid, item_id])
values = np.ones(len(indices), dtype=np.float32)
indices = np.array(indices, dtype=np.int32)
shape = np.array([len(x), self.max_seq_length, self.n_items], dtype=np.int64)
return indices, values, shape
def compute_loss(self, logits, y):
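        # Weighted sigmoid cross-entropy: positive (target) items are
        # up-weighted by the per-row negative/positive count ratio, while
        # negative items are penalized through sigmoid(logits - pos_min),
        # pushing them below the smallest positive logit of the same row.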
sigmoid_logits = tf.nn.sigmoid(logits)
neg_y = 1.0 - y
pos_logits = y * logits
pos_max = tf.reduce_max(pos_logits, axis=1)
pos_max = tf.expand_dims(pos_max, axis=-1)
pos_min = tf.reduce_min(pos_logits + neg_y * pos_max, axis=1)
pos_min = tf.expand_dims(pos_min, axis=-1)
nb_pos, nb_neg = tf.count_nonzero(y, axis=1), tf.count_nonzero(neg_y, axis=1)
ratio = tf.cast(nb_neg, dtype=tf.float32) / tf.cast(nb_pos, dtype=tf.float32)
pos_weight = tf.expand_dims(ratio, axis=-1)
loss = y * -tf.log(sigmoid_logits) * pos_weight + neg_y * -tf.log(
1.0 - tf.nn.sigmoid(logits - pos_min)
)
return tf.reduce_mean(loss + 1e-8)
def relu_with_threshold(self, x, threshold):
return tf.nn.relu(x - tf.abs(threshold))

cornac/models/beacon/recom_beacon.py (new file)
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from collections import Counter
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix, diags
from tqdm.auto import trange
from ..recommender import NextBasketRecommender
class Beacon(NextBasketRecommender):
"""Correlation-Sensitive Next-Basket Recommendation
Parameters
----------
name: string, default: 'Beacon'
The name of the recommender model.
emb_dim: int, optional, default: 2
Embedding dimension
rnn_unit: int, optional, default: 4
    Number of hidden units in the RNN cell.
alpha: float, optional, default: 0.5
Hyperparameter to control the balance between correlative and sequential associations.
rnn_cell_type: str, optional, default: 'LSTM'
    RNN cell type, one of ['LSTM', 'GRU', None].
    If None, BasicRNNCell will be used.
dropout_rate: float, optional, default: 0.5
    Dropout rate applied to the inputs of the recurrent layer.
nb_hop: int, optional, default: 1
    Number of hops used to construct the item correlation matrix.
    If 0, a zero matrix is used (correlations are ignored).
n_epochs: int, optional, default: 15
Number of training epochs
batch_size: int, optional, default: 32
Batch size
lr: float, optional, default: 0.001
Initial value of learning rate for the optimizer.
verbose: boolean, optional, default: False
When True, running logs are displayed.
seed: int, optional, default: None
Random seed
References
----------
Le, Duc Trong, Hady W. Lauw, and Yuan Fang.
Correlation-Sensitive Next-Basket Recommendation.
In Proceedings of the International Joint Conference on Artificial Intelligence (IJCAI), 2019.
"""
def __init__(
self,
name="Beacon",
emb_dim=2,
rnn_unit=4,
alpha=0.5,
rnn_cell_type="LSTM",
dropout_rate=0.5,
nb_hop=1,
n_epochs=15,
batch_size=32,
lr=0.001,
trainable=True,
verbose=False,
seed=None,
):
super().__init__(name=name, trainable=trainable, verbose=verbose)
self.n_epochs = n_epochs
self.batch_size = batch_size
self.nb_hop = nb_hop
self.emb_dim = emb_dim
self.rnn_unit = rnn_unit
self.alpha = alpha
self.rnn_cell_type = rnn_cell_type
self.dropout_rate = dropout_rate
self.seed = seed
self.lr = lr
def fit(self, train_set, val_set=None):
import tensorflow.compat.v1 as tf
from .beacon_tf import BeaconModel
tf.disable_eager_execution()
# less verbose TF
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.logging.set_verbosity(tf.logging.ERROR)
super().fit(train_set=train_set, val_set=val_set)
self.correlation_matrix = self._build_correlation_matrix(
train_set=train_set, val_set=val_set, n_items=self.total_items
)
self.item_probs = self._compute_item_probs(
train_set=train_set, val_set=val_set, n_items=self.total_items
)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = False
sess = tf.Session(config=config)
self.model = BeaconModel(
sess,
self.emb_dim,
self.rnn_unit,
self.alpha,
train_set.max_basket_size,
self.total_items,
self.item_probs,
self.correlation_matrix,
self.rnn_cell_type,
self.dropout_rate,
self.seed,
self.lr,
)
        sess.run(tf.global_variables_initializer())  # initialize model parameters
last_loss = np.inf
last_val_loss = np.inf
loop = trange(self.n_epochs, disable=not self.verbose)
loop.set_postfix(
loss=last_loss,
val_loss=last_val_loss,
)
train_pool = []
validation_pool = []
for _ in loop:
train_loss = 0.0
trained_cnt = 0
for batch_basket_items in self._data_iter(
train_set, shuffle=True, current_pool=train_pool
):
s, s_length, y = self._transform_data(
batch_basket_items, self.total_items
)
loss = self.model.train_batch(s, s_length, y)
current_batch_size = len(batch_basket_items)
trained_cnt += current_batch_size
train_loss += loss * current_batch_size
last_loss = train_loss / trained_cnt
loop.set_postfix(
loss=last_loss,
val_loss=last_val_loss,
)
if val_set is not None:
val_loss = 0.0
val_cnt = 0
for batch_basket_items in self._data_iter(
val_set, shuffle=False, current_pool=validation_pool
):
s, s_length, y = self._transform_data(
batch_basket_items, self.total_items
)
loss = self.model.validate_batch(s, s_length, y)
current_batch_size = len(batch_basket_items)
val_cnt += current_batch_size
val_loss += loss * current_batch_size
last_val_loss = val_loss / val_cnt
loop.set_postfix(
loss=last_loss,
val_loss=last_val_loss,
)
return self
    def _data_iter(self, data_set, shuffle=False, current_pool=None):
        """Yield fixed-size batches; sequences left over at the end of an epoch
        remain in `current_pool` and are consumed in the next epoch."""
        if current_pool is None:
            current_pool = []
        for _, _, batch_basket_items in data_set.ubi_iter(
            batch_size=self.batch_size, shuffle=shuffle
        ):
            current_pool += batch_basket_items
            if len(current_pool) >= self.batch_size:
                yield current_pool[: self.batch_size]
                del current_pool[self.batch_size :]
def _transform_data(self, batch_basket_items, n_items):
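        # s: each user's baskets except the last; y: multi-hot vector over the
        # items of the last (target) basket.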
assert len(batch_basket_items) == self.batch_size
s = [basket_items[:-1] for basket_items in batch_basket_items]
s_length = [len(b) for b in s]
y = np.zeros((self.batch_size, n_items), dtype="int32")
for inc, basket_items in enumerate(batch_basket_items):
y[inc, basket_items[-1]] = 1
return s, s_length, y
def _build_correlation_matrix(self, train_set, val_set, n_items):
if self.nb_hop == 0:
return csr_matrix((n_items, n_items), dtype="float32")
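        # Count co-occurrences of each (sorted) item pair within a basket.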
pairs_cnt = Counter()
for _, _, [basket_items] in train_set.ubi_iter(1, shuffle=False):
for items in basket_items:
current_items = np.unique(items)
for i in range(len(current_items) - 1):
for j in range(i + 1, len(current_items)):
pairs_cnt[(current_items[i], current_items[j])] += 1
if val_set is not None:
for _, _, [basket_items] in val_set.ubi_iter(1, shuffle=False):
for items in basket_items:
current_items = np.unique(items)
for i in range(len(current_items) - 1):
for j in range(i + 1, len(current_items)):
pairs_cnt[(current_items[i], current_items[j])] += 1
data, row, col = [], [], []
for pair, cnt in pairs_cnt.most_common():
data.append(cnt)
row.append(pair[0])
col.append(pair[1])
correlation_matrix = csc_matrix(
(data, (row, col)), shape=(n_items, n_items), dtype="float32"
)
correlation_matrix = self._normalize(correlation_matrix)
w_mul = correlation_matrix
coeff = 1.0
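        # Accumulate higher-order hops with geometrically decaying weight
        # (0.85 per extra hop); self-loops are removed and the matrix is
        # re-normalized at every hop.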
for _ in range(1, self.nb_hop):
coeff *= 0.85
w_mul *= correlation_matrix
w_mul = self._remove_diag(w_mul)
w_adj_matrix = self._normalize(w_mul)
correlation_matrix += coeff * w_adj_matrix
return correlation_matrix
def _remove_diag(self, adj_matrix):
new_adj_matrix = csr_matrix(adj_matrix)
new_adj_matrix.setdiag(0.0)
new_adj_matrix.eliminate_zeros()
return new_adj_matrix
def _normalize(self, adj_matrix: csr_matrix):
"""Symmetrically normalize adjacency matrix."""
row_sum = adj_matrix.sum(1).A.squeeze()
d_inv_sqrt = np.power(
row_sum,
-0.5,
out=np.zeros_like(row_sum, dtype="float32"),
where=row_sum != 0,
)
d_mat_inv_sqrt = diags(d_inv_sqrt)
normalized_matrix = (
adj_matrix.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)
)
return normalized_matrix.tocsr()
def _compute_item_probs(self, train_set, val_set, n_items):
item_freq = Counter(train_set.uir_tuple[1])
if val_set is not None:
item_freq += Counter(val_set.uir_tuple[1])
item_probs = np.zeros(n_items, dtype="float32")
        total_cnt = len(train_set.uir_tuple[1]) + (
            len(val_set.uir_tuple[1]) if val_set is not None else 0
        )
for iid, cnt in item_freq.items():
item_probs[iid] = cnt / total_cnt
return item_probs
def score(self, user_idx, history_baskets, **kwargs):
s = [history_baskets]
s_length = [len(history_baskets)]
return self.model.predict(s, s_length)

cornac/models/beacon/requirements.txt (new file)
tensorflow[and-cuda]==2.15.0

docs/source/api_ref/models.rst

...@@ -54,6 +54,11 @@ Temporal-Item-Frequency-based User-KNN (TIFUKNN)
.. automodule:: cornac.models.tifuknn.recom_tifuknn
    :members:

Correlation-Sensitive Next-Basket Recommendation (Beacon)
-----------------------------------------------------------

.. automodule:: cornac.models.beacon.recom_beacon
    :members:

Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)
-----------------------------------------------------------

.. automodule:: cornac.models.ease.recom_ease
...

examples/README.md

...@@ -120,6 +120,8 @@
[gp_top_tafeng.py](gp_top_tafeng.py) - Next-basket recommendation model that merely uses item top frequency.

[beacon_tafeng.py](beacon_tafeng.py) - Correlation-Sensitive Next-Basket Recommendation (Beacon).

[tifuknn_tafeng.py](tifuknn_tafeng.py) - Example of Temporal-Item-Frequency-based User-KNN (TIFUKNN).

[upcf_tafeng.py](upcf_tafeng.py) - Example of Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF).
...

examples/beacon_tafeng.py (new file)
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Example of Correlation-Sensitive Next-Basket Recommendation Model (Beacon)"""
import cornac
from cornac.eval_methods import NextBasketEvaluation
from cornac.metrics import NDCG, HitRatio, Recall
from cornac.models import Beacon
data = cornac.datasets.tafeng.load_basket(
reader=cornac.data.Reader(
min_basket_size=3, max_basket_size=50, min_basket_sequence=2
)
)
next_basket_eval = NextBasketEvaluation(
data=data, fmt="UBITJson", test_size=0.2, val_size=0.08, seed=123, verbose=True
)
models = [
Beacon(
emb_dim=2,
rnn_unit=4,
alpha=0.5,
rnn_cell_type="LSTM",
dropout_rate=0.5,
nb_hop=1,
n_epochs=15,
batch_size=32,
lr=0.001,
verbose=True,
)
]
metrics = [
Recall(k=10),
Recall(k=50),
NDCG(k=10),
NDCG(k=50),
HitRatio(k=10),
HitRatio(k=50),
]
cornac.Experiment(eval_method=next_basket_eval, models=models, metrics=metrics).run()