From 2be4463bcc6dcc07a4d8edf7a01f236968abe344 Mon Sep 17 00:00:00 2001
From: Qrh <17882988+RuihongQiu@users.noreply.github.com>
Date: Fri, 9 Jul 2021 22:00:16 +1000
Subject: [PATCH] Add AMR model (#420)

---
 README.md                          |   1 +
 cornac/models/__init__.py          |   1 +
 cornac/models/amr/__init__.py      |  16 ++
 cornac/models/amr/recom_amr.py     | 298 +++++++++++++++++++++++++++++
 cornac/models/amr/requirements.txt |   1 +
 docs/source/models.rst             |   5 +
 examples/amr_clothing.py           |  62 ++++++
 examples/causalrec_clothing.py     |   5 +-
 8 files changed, 387 insertions(+), 2 deletions(-)
 create mode 100644 cornac/models/amr/__init__.py
 create mode 100644 cornac/models/amr/recom_amr.py
 create mode 100644 cornac/models/amr/requirements.txt
 create mode 100644 examples/amr_clothing.py

diff --git a/README.md b/README.md
index 9ce625cf..e939d364 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
 | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | [requirements.txt](cornac/models/bivaecf/requirements.txt) | [PreferredAI/bi-vae](https://github.com/PreferredAI/bi-vae)
 | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py)
 | | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER)
+| 2020 | [Adversarial Training Towards Robust Multimedia Recommender System (AMR)](cornac/models/amr), [paper](https://ieeexplore.ieee.org/document/8618394) | [requirements.txt](cornac/models/amr/requirements.txt) | [amr_clothing.py](examples/amr_clothing.py)
 | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py)
 | | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_exp.py](examples/mter_example.py)
 | | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | [requirements.txt](cornac/models/narre/requirements.txt) | [narre_example.py](examples/narre_example.py)
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index d032090a..85dcf275 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -15,6 +15,7 @@
 
 from .recommender import Recommender
 
+from .amr import AMR
 from .baseline_only import BaselineOnly
 from .bivaecf import BiVAECF
 from .bpr import BPR
diff --git a/cornac/models/amr/__init__.py b/cornac/models/amr/__init__.py
new file mode 100644
index 00000000..14c15d2f
--- /dev/null
+++ b/cornac/models/amr/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from .recom_amr import AMR
\ No newline at end of file
diff --git a/cornac/models/amr/recom_amr.py b/cornac/models/amr/recom_amr.py
new file mode 100644
index 00000000..1ef773df
--- /dev/null
+++ b/cornac/models/amr/recom_amr.py
@@ -0,0 +1,298 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from ..recommender import Recommender
+from ...exception import CornacException
+from ...exception import ScoreException
+from ...utils import fast_dot
+from ...utils.common import intersects
+from ...utils import get_rng
+from ...utils.init_utils import zeros, xavier_uniform
+
+
+class AMR(Recommender):
+    """Adversarial Training Towards Robust Multimedia Recommender System.
+
+    Parameters
+    ----------
+    k: int, optional, default: 10
+        The dimension of the gamma latent factors.
+
+    k2: int, optional, default: 10
+        The dimension of the theta latent factors.
+
+    n_epochs: int, optional, default: 50
+        Maximum number of epochs for SGD.
+
+    batch_size: int, optional, default: 100
+        The batch size for SGD.
+
+    learning_rate: float, optional, default: 0.005
+        The learning rate for SGD.
+
+    lambda_w: float, optional, default: 0.01
+        The regularization hyper-parameter for latent factor weights.
+
+    lambda_b: float, optional, default: 0.01
+        The regularization hyper-parameter for biases.
+
+    lambda_e: float, optional, default: 0.0
+        The regularization hyper-parameter for the embedding matrix E.
+
+    lambda_adv: float, optional, default: 1.0
+        The weight of the adversarial sample loss, i.e., the regularization
+        hyper-parameter in Eq. (8) and (10) of the paper.
+
+    use_gpu: boolean, optional, default: False
+        Whether or not to use GPU to speed up training.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model
+        is already pre-trained (Gu, Gi, and E are not None).
+
+    verbose: boolean, optional, default: True
+        When True, running logs are displayed.
+
+    init_params: dictionary, optional, default: None
+        Initial parameters, e.g., init_params = {'Gu': gamma_user,
+        'Gi': gamma_item, 'E': emb_matrix}
+
+    seed: int, optional, default: None
+        Random seed for weight initialization.
+
+    References
+    ----------
+    * Tang, J., Du, X., He, X., Yuan, F., Tian, Q., and Chua, T. (2020). Adversarial Training Towards Robust Multimedia Recommender System.
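+
+    Notes
+    -----
+    Editorial sketch of what this implementation optimizes (see the paper for
+    the exact formulation): the preference score used here is
+    ``score(u, i) = gamma_u . gamma_i + gamma_u . (f_i E)``, where ``f_i`` is
+    the CNN visual feature of item ``i`` and ``E`` projects it into the latent
+    space. Training maximizes the BPR log-likelihood of this score on clean
+    features plus ``lambda_adv`` times the same likelihood computed on
+    adversarially perturbed features.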
+ """ + + def __init__( + self, + name="AMR", + k=10, + k2=10, + n_epochs=50, + batch_size=100, + learning_rate=0.005, + lambda_w=0.01, + lambda_b=0.01, + lambda_e=0.0, + lambda_adv=1.0, + use_gpu=False, + trainable=True, + verbose=True, + init_params=None, + seed=None, + ): + super().__init__(name=name, trainable=trainable, verbose=verbose) + self.k = k + self.k2 = k2 + self.n_epochs = n_epochs + self.batch_size = batch_size + self.learning_rate = learning_rate + self.lambda_w = lambda_w + self.lambda_b = lambda_b + self.lambda_e = lambda_e + self.lambda_adv = lambda_adv + self.use_gpu = use_gpu + self.seed = seed + + # Init params if provided + self.init_params = {} if init_params is None else init_params + self.gamma_user = self.init_params.get("Gu", None) + self.gamma_item = self.init_params.get("Gi", None) + self.emb_matrix = self.init_params.get("E", None) + + def _init(self, n_users, n_items, features): + rng = get_rng(self.seed) + + if self.gamma_user is None: + self.gamma_user = xavier_uniform((n_users, self.k), rng) + if self.gamma_item is None: + self.gamma_item = xavier_uniform((n_items, self.k), rng) + if self.emb_matrix is None: + self.emb_matrix = xavier_uniform((features.shape[1], self.k), rng) + + # pre-computed for faster evaluation + self.theta_item = np.matmul(features, self.emb_matrix) + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + + if train_set.item_image is None: + raise CornacException("item_image modality is required but None.") + + # Item visual feature from CNN + train_features = train_set.item_image.features[: self.train_set.total_items] + train_features = train_features.astype(np.float32) + self._init( + n_users=train_set.total_users, + n_items=train_set.total_items, + features=train_features, + ) + + if self.trainable: + self._fit_torch(train_features) + + return self + + def _fit_torch(self, train_features): + import torch + + def _l2_loss(*tensors): + l2_loss = 0 + for tensor in tensors: + l2_loss += tensor.pow(2).sum() + return l2_loss / 2 + + def _inner(a, b): + return (a * b).sum(dim=1) + + dtype = torch.float + device = ( + torch.device("cuda:0") + if (self.use_gpu and torch.cuda.is_available()) + else torch.device("cpu") + ) + + # set requireds_grad=True to get the adversarial gradient + # if F is not put into the optimization list of parameters + # it won't be updated + F = torch.tensor( + train_features, device=device, dtype=dtype, requires_grad=True + ) + # Learned parameters + Gu = torch.tensor( + self.gamma_user, device=device, dtype=dtype, requires_grad=True + ) + Gi = torch.tensor( + self.gamma_item, device=device, dtype=dtype, requires_grad=True + ) + E = torch.tensor( + self.emb_matrix, device=device, dtype=dtype, requires_grad=True + ) + + optimizer = torch.optim.Adam([Gu, Gi, E], lr=self.learning_rate) + + for epoch in range(1, self.n_epochs + 1): + sum_loss = 0.0 + count = 0 + progress_bar = tqdm( + total=self.train_set.num_batches(self.batch_size), + desc="Epoch {}/{}".format(epoch, self.n_epochs), + disable=not self.verbose, + ) + for batch_u, batch_i, batch_j in self.train_set.uij_iter( + self.batch_size, shuffle=True + ): + gamma_u = 
+                gamma_u = Gu[batch_u]
+                gamma_i = Gi[batch_i]
+                gamma_j = Gi[batch_j]
+                feat_i = F[batch_i]
+                feat_j = F[batch_j]
+
+                gamma_diff = gamma_i - gamma_j
+                feat_diff = feat_i - feat_j
+
+                Xuij = (
+                    _inner(gamma_u, gamma_diff)
+                    + _inner(gamma_u, feat_diff.mm(E))
+                )
+
+                log_likelihood = torch.nn.functional.logsigmoid(Xuij).sum()
+
+                # adversarial part: back-propagate the BPR log-likelihood to the
+                # raw features and perturb them against that gradient, i.e., in
+                # the direction that decreases the likelihood the most
+                feat_i.retain_grad()
+                feat_j.retain_grad()
+                log_likelihood.backward(retain_graph=True)
+                feat_i_delta = feat_i.grad
+                feat_j_delta = feat_j.grad
+
+                adv_feat_diff = feat_diff - (feat_i_delta - feat_j_delta)
+                adv_Xuij = (
+                    _inner(gamma_u, gamma_diff)
+                    + _inner(gamma_u, adv_feat_diff.mm(E))
+                )
+
+                adv_log_likelihood = torch.nn.functional.logsigmoid(adv_Xuij).sum()
+
+                reg = (
+                    _l2_loss(gamma_u, gamma_i, gamma_j) * self.lambda_w
+                    + _l2_loss(E) * self.lambda_e
+                )
+
+                loss = -log_likelihood - self.lambda_adv * adv_log_likelihood + reg
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                sum_loss += loss.data.item()
+                count += len(batch_u)
+                if count % (self.batch_size * 10) == 0:
+                    progress_bar.set_postfix(loss=(sum_loss / count))
+                progress_bar.update(1)
+            progress_bar.close()
+
+        print("Optimization finished!")
+
+        self.gamma_user = Gu.data.cpu().numpy()
+        self.gamma_item = Gi.data.cpu().numpy()
+        self.emb_matrix = E.data.cpu().numpy()
+        # pre-computed for faster evaluation
+        self.theta_item = F.mm(E).data.cpu().numpy()
+
+    def score(self, user_idx, item_idx=None):
+        """Predict the scores/ratings of a user for an item.
+
+        Parameters
+        ----------
+        user_idx: int, required
+            The index of the user for whom to perform score prediction.
+
+        item_idx: int, optional, default: None
+            The index of the item for which to perform score prediction.
+            If None, scores for all known items will be returned.
+
+        Returns
+        -------
+        res : A scalar or a Numpy array
+            Relative scores that the user gives to the item or to all known items
+
+        """
+        if item_idx is None:
+            known_item_scores = np.zeros(self.gamma_item.shape[0], dtype=np.float32)
+            fast_dot(self.gamma_user[user_idx], self.gamma_item, known_item_scores)
+            fast_dot(self.gamma_user[user_idx], self.theta_item, known_item_scores)
+            return known_item_scores
+        else:
+            item_score = np.dot(self.gamma_item[item_idx], self.gamma_user[user_idx])
+            item_score += np.dot(self.theta_item[item_idx], self.gamma_user[user_idx])
+            return item_score
diff --git a/cornac/models/amr/requirements.txt b/cornac/models/amr/requirements.txt
new file mode 100644
index 00000000..b54f97d6
--- /dev/null
+++ b/cornac/models/amr/requirements.txt
@@ -0,0 +1 @@
+torch>=0.4.1
diff --git a/docs/source/models.rst b/docs/source/models.rst
index b4d39a2c..eae175be 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -27,6 +27,11 @@ Explainable Recommendation with Comparative Constraints on Product Aspects (Comp
 .. automodule:: cornac.models.comparer.recom_comparer_obj
    :members:
 
+
+Adversarial Training Towards Robust Multimedia Recommender System (AMR)
+--------------------------------------------------------------------------------
+.. automodule:: cornac.models.amr.recom_amr
+   :members:
 
 Collaborative Context Poisson Factorization (C2PF)
 ----------------------------------------------------
diff --git a/examples/amr_clothing.py b/examples/amr_clothing.py
new file mode 100644
index 00000000..43150961
--- /dev/null
+++ b/examples/amr_clothing.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Example for Adversarial Training Towards Robust Multimedia Recommender System
+"""
+
+import cornac
+from cornac.datasets import amazon_clothing
+from cornac.data import ImageModality
+from cornac.eval_methods import RatioSplit
+
+
+# AMR is a visually-aware recommender trained with adversarial perturbations on the item visual features
+# The necessary data can be loaded as follows
+feedback = amazon_clothing.load_feedback()
+features, item_ids = amazon_clothing.load_visual_feature()  # BIG file
+
+# Instantiate an ImageModality, which makes it convenient to work with visual auxiliary information
+# For more details, please refer to the tutorial on how to work with auxiliary data
+item_image_modality = ImageModality(features=features, ids=item_ids, normalized=True)
+
+# Define an evaluation method to split feedback into train and test sets
+ratio_split = RatioSplit(
+    data=feedback,
+    test_size=0.1,
+    rating_threshold=0.5,
+    exclude_unknowns=True,
+    verbose=True,
+    item_image=item_image_modality,
+)
+
+# Instantiate AMR
+amr = cornac.models.AMR(
+    k=32,
+    k2=32,
+    n_epochs=1,
+    batch_size=100,
+    learning_rate=0.001,
+    lambda_w=1,
+    lambda_b=0.01,
+    lambda_e=0.0,
+    lambda_adv=1.0,
+    use_gpu=True,
+)
+
+# Instantiate evaluation measures
+rec_50 = cornac.metrics.Recall(k=50)
+
+# Put everything together into an experiment and run it
+cornac.Experiment(eval_method=ratio_split, models=[amr], metrics=[rec_50]).run()
diff --git a/examples/causalrec_clothing.py b/examples/causalrec_clothing.py
index b00cc943..011036ed 100644
--- a/examples/causalrec_clothing.py
+++ b/examples/causalrec_clothing.py
@@ -52,9 +52,10 @@ causalrec = cornac.models.CausalRec(
     lambda_w=1,
     lambda_b=0.01,
     lambda_e=0.0,
-    use_gpu=True,
     mean_feat=features.mean(axis=0),
-    tanh=1
+    tanh=1,
+    lambda_2=0.8,
+    use_gpu=True,
 )
 
 # Instantiate evaluation measures
-- 
GitLab