From 2be4463bcc6dcc07a4d8edf7a01f236968abe344 Mon Sep 17 00:00:00 2001
From: Qrh <17882988+RuihongQiu@users.noreply.github.com>
Date: Fri, 9 Jul 2021 22:00:16 +1000
Subject: [PATCH] Add AMR model (#420)

---
 README.md                          |   1 +
 cornac/models/__init__.py          |   1 +
 cornac/models/amr/__init__.py      |  16 ++
 cornac/models/amr/recom_amr.py     | 298 +++++++++++++++++++++++++++++
 cornac/models/amr/requirements.txt |   1 +
 docs/source/models.rst             |   5 +
 examples/amr_clothing.py           |  62 ++++++
 examples/causalrec_clothing.py     |   5 +-
 8 files changed, 387 insertions(+), 2 deletions(-)
 create mode 100644 cornac/models/amr/__init__.py
 create mode 100644 cornac/models/amr/recom_amr.py
 create mode 100644 cornac/models/amr/requirements.txt
 create mode 100644 examples/amr_clothing.py

diff --git a/README.md b/README.md
index 9ce625cf..e939d364 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
 | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | [requirements.txt](cornac/models/bivaecf/requirements.txt) | [PreferredAI/bi-vae](https://github.com/PreferredAI/bi-vae)
 |      | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py)
 |      | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER)
+| 2020 | [Adversarial Training Towards Robust Multimedia Recommender System (AMR)](cornac/models/amr), [paper](https://ieeexplore.ieee.org/document/8618394) | [requirements.txt](cornac/models/amr/requirements.txt) | [amr_clothing.py](examples/amr_clothing.py)
 | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py)
 |      | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_exp.py](examples/mter_example.py)
 |      | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | [requirements.txt](cornac/models/narre/requirements.txt) | [narre_example.py](examples/narre_example.py)
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index d032090a..85dcf275 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -15,6 +15,7 @@
 
 from .recommender import Recommender
 
+from .amr import AMR
 from .baseline_only import BaselineOnly
 from .bivaecf import BiVAECF
 from .bpr import BPR
diff --git a/cornac/models/amr/__init__.py b/cornac/models/amr/__init__.py
new file mode 100644
index 00000000..14c15d2f
--- /dev/null
+++ b/cornac/models/amr/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from .recom_amr import AMR
\ No newline at end of file
diff --git a/cornac/models/amr/recom_amr.py b/cornac/models/amr/recom_amr.py
new file mode 100644
index 00000000..1ef773df
--- /dev/null
+++ b/cornac/models/amr/recom_amr.py
@@ -0,0 +1,298 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from ..recommender import Recommender
+from ...exception import CornacException
+from ...exception import ScoreException
+from ...utils import fast_dot
+from ...utils.common import intersects
+from ...utils import get_rng
+from ...utils.init_utils import zeros, xavier_uniform
+
+
+class AMR(Recommender):
+    """Adversarial Training Towards Robust Multimedia Recommender System.
+
+    Parameters
+    ----------
+    k: int, optional, default: 10
+        The dimension of the gamma latent factors and of the projected visual item factors.
+
+    k2: int, optional, default: 10
+        Stored on the model but not used by this implementation (E projects features to k dimensions).
+
+    n_epochs: int, optional, default: 50
+        Number of training epochs.
+
+    batch_size: int, optional, default: 100
+        The mini-batch size used during training.
+
+    learning_rate: float, optional, default: 0.005
+        The learning rate for the Adam optimizer.
+
+    lambda_w: float, optional, default: 0.01
+        The regularization hyper-parameter for latent factor weights.
+
+    lambda_b: float, optional, default: 0.01
+        Stored on the model but not used by this implementation (no bias terms are learned).
+
+    lambda_e: float, optional, default: 0.0
+        The regularization hyper-parameter for embedding matrix E.
+
+    lambda_adv: float, optional, default: 1.0
+        The regularization hyper-parameter in Eq. (8) and (10) for the adversarial sample loss.
+
+    use_gpu: boolean, optional, default: False
+        Whether or not to use GPU (only when CUDA is available) to speed up training.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already \
+        pre-trained (Gu, Gi and E are provided through init_params).
+
+    verbose: boolean, optional, default: True
+        When True, running logs are displayed.
+
+    init_params: dictionary, optional, default: None
+        Initial parameters, e.g., init_params = {'Gu': gamma_user,
+        'Gi': gamma_item, 'E': emb_matrix}. Only these three keys are read.
+
+    seed: int, optional, default: None
+        Random seed for weight initialization.
+
+    References
+    ----------
+    * Tang, J., Du, X., He, X., Yuan, F., Tian, Q., and Chua, T. (2020). Adversarial Training Towards Robust Multimedia Recommender System.
+    """
+    
+    def __init__(
+            self,
+            name="AMR",
+            k=10,
+            k2=10,
+            n_epochs=50,
+            batch_size=100,
+            learning_rate=0.005,
+            lambda_w=0.01,
+            lambda_b=0.01,
+            lambda_e=0.0,
+            lambda_adv=1.0,
+            use_gpu=False,
+            trainable=True,
+            verbose=True,
+            init_params=None,
+            seed=None,
+    ):
+        super().__init__(name=name, trainable=trainable, verbose=verbose)
+        self.k = k
+        self.k2 = k2
+        self.n_epochs = n_epochs
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.lambda_w = lambda_w
+        self.lambda_b = lambda_b
+        self.lambda_e = lambda_e
+        self.lambda_adv = lambda_adv
+        self.use_gpu = use_gpu
+        self.seed = seed
+        
+        # Init params if provided (only the 'Gu', 'Gi' and 'E' keys are consulted)
+        self.init_params = {} if init_params is None else init_params
+        self.gamma_user = self.init_params.get("Gu", None)
+        self.gamma_item = self.init_params.get("Gi", None)
+        self.emb_matrix = self.init_params.get("E", None)
+    
+    def _init(self, n_users, n_items, features):
+        rng = get_rng(self.seed)
+        
+        if self.gamma_user is None:
+            self.gamma_user = xavier_uniform((n_users, self.k), rng)
+        if self.gamma_item is None:
+            self.gamma_item = xavier_uniform((n_items, self.k), rng)
+        if self.emb_matrix is None:
+            self.emb_matrix = xavier_uniform((features.shape[1], self.k), rng)
+        
+        # pre-computed for faster evaluation (visual item factors = features projected by E)
+        self.theta_item = np.matmul(features, self.emb_matrix)
+    
+    def fit(self, train_set, val_set=None):
+        """Fit the model to observations.
+
+        Parameters
+        ----------
+        train_set: :obj:`cornac.data.Dataset`, required
+            User-Item preference data as well as additional modalities.
+
+        val_set: :obj:`cornac.data.Dataset`, optional, default: None
+            User-Item preference data for model selection purposes (e.g., early stopping).
+
+        Returns
+        -------
+        self : object
+        """
+        Recommender.fit(self, train_set, val_set)
+        
+        if train_set.item_image is None:
+            raise CornacException("item_image modality is required but None.")
+        
+        # Item visual feature from CNN
+        train_features = train_set.item_image.features[: self.train_set.total_items]
+        train_features = train_features.astype(np.float32)
+        self._init(
+            n_users=train_set.total_users,
+            n_items=train_set.total_items,
+            features=train_features,
+        )
+        
+        if self.trainable:
+            self._fit_torch(train_features)
+        
+        return self
+    
+    def _fit_torch(self, train_features):
+        import torch
+        
+        def _l2_loss(*tensors):
+            l2_loss = 0
+            for tensor in tensors:
+                l2_loss += tensor.pow(2).sum()
+            return l2_loss / 2
+        
+        def _inner(a, b):
+            return (a * b).sum(dim=1)
+        
+        dtype = torch.float
+        device = (
+            torch.device("cuda:0")
+            if (self.use_gpu and torch.cuda.is_available())
+            else torch.device("cpu")
+        )
+        
+        # set requires_grad=True to get the adversarial gradient;
+        # F is not passed to the optimizer below, so the features
+        # themselves are never updated
+        F = torch.tensor(
+            train_features, device=device, dtype=dtype, requires_grad=True
+        )
+        # Learned parameters
+        Gu = torch.tensor(
+            self.gamma_user, device=device, dtype=dtype, requires_grad=True
+        )
+        Gi = torch.tensor(
+            self.gamma_item, device=device, dtype=dtype, requires_grad=True
+        )
+        E = torch.tensor(
+            self.emb_matrix, device=device, dtype=dtype, requires_grad=True
+        )
+        
+        optimizer = torch.optim.Adam([Gu, Gi, E], lr=self.learning_rate)
+        
+        for epoch in range(1, self.n_epochs + 1):
+            sum_loss = 0.0
+            count = 0
+            progress_bar = tqdm(
+                total=self.train_set.num_batches(self.batch_size),
+                desc="Epoch {}/{}".format(epoch, self.n_epochs),
+                disable=not self.verbose,
+            )
+            for batch_u, batch_i, batch_j in self.train_set.uij_iter(
+                    self.batch_size, shuffle=True
+            ):
+                gamma_u = Gu[batch_u]
+                gamma_i = Gi[batch_i]
+                gamma_j = Gi[batch_j]
+                feat_i = F[batch_i]
+                feat_j = F[batch_j]
+                
+                gamma_diff = gamma_i - gamma_j
+                feat_diff = feat_i - feat_j
+                
+                Xuij = (
+                        _inner(gamma_u, gamma_diff)
+                        + _inner(gamma_u, feat_diff.mm(E))
+                )
+                
+                log_likelihood = torch.nn.functional.logsigmoid(Xuij).sum()
+                
+                # adversarial part: gradient of the log-likelihood w.r.t. the sampled features
+                feat_i.retain_grad()
+                feat_j.retain_grad()
+                log_likelihood.backward(retain_graph=True)
+                feat_i_delta = feat_i.grad
+                feat_j_delta = feat_j.grad
+                
+                adv_feat_diff = feat_diff - (feat_i_delta - feat_j_delta)  # raw gradient step, no eps scaling / normalisation
+                adv_Xuij = (
+                        _inner(gamma_u, gamma_diff)
+                        + _inner(gamma_u, adv_feat_diff.mm(E))
+                )
+                
+                adv_log_likelihood = torch.nn.functional.logsigmoid(adv_Xuij).sum()
+                
+                reg = (
+                        _l2_loss(gamma_u, gamma_i, gamma_j) * self.lambda_w
+                        + _l2_loss(E) * self.lambda_e
+                )
+                
+                loss = -log_likelihood - self.lambda_adv * adv_log_likelihood + reg
+                
+                optimizer.zero_grad()  # drop gradients left on Gu/Gi/E by the adversarial backward pass
+                loss.backward()
+                optimizer.step()
+                
+                sum_loss += loss.data.item()
+                count += len(batch_u)
+                if count % (self.batch_size * 10) == 0:
+                    progress_bar.set_postfix(loss=(sum_loss / count))
+                progress_bar.update(1)
+            progress_bar.close()
+        
+        print("Optimization finished!")
+        
+        self.gamma_user = Gu.data.cpu().numpy()
+        self.gamma_item = Gi.data.cpu().numpy()
+        self.emb_matrix = E.data.cpu().numpy()
+        # pre-computed for faster evaluation (refreshed with the trained E)
+        self.theta_item = F.mm(E).data.cpu().numpy()
+    
+    def score(self, user_idx, item_idx=None):
+        """Predict the scores/ratings of a user for an item.
+
+        Parameters
+        ----------
+        user_idx: int, required
+            The index of the user for whom to perform score prediction.
+
+        item_idx: int, optional, default: None
+            The index of the item for which to perform score prediction.
+            If None, scores for all known items will be returned.
+
+        Returns
+        -------
+        res : A scalar or a Numpy array
+            Relative scores that the user gives to the item or to all known items
+
+        """
+        if item_idx is None:
+            known_item_scores = np.zeros(self.gamma_item.shape[0], dtype=np.float32)
+            fast_dot(self.gamma_user[user_idx], self.gamma_item, known_item_scores)
+            fast_dot(self.gamma_user[user_idx], self.theta_item, known_item_scores)
+            return known_item_scores
+        else:
+            item_score = np.dot(self.gamma_item[item_idx], self.gamma_user[user_idx])
+            item_score += np.dot(self.theta_item[item_idx], self.gamma_user[user_idx])
+            return item_score
diff --git a/cornac/models/amr/requirements.txt b/cornac/models/amr/requirements.txt
new file mode 100644
index 00000000..b54f97d6
--- /dev/null
+++ b/cornac/models/amr/requirements.txt
@@ -0,0 +1 @@
+torch>=0.4.1
diff --git a/docs/source/models.rst b/docs/source/models.rst
index b4d39a2c..eae175be 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -27,6 +27,11 @@ Explainable Recommendation with Comparative Constraints on Product Aspects (Comp
    
 .. automodule:: cornac.models.comparer.recom_comparer_obj
    :members:
+   
+Adversarial Training Towards Robust Multimedia Recommender System (AMR)
+-----------------------------------------------------------------------
+.. automodule:: cornac.models.amr.recom_amr
+   :members:
 
 Collaborative Context Poisson Factorization (C2PF)
 ----------------------------------------------------
diff --git a/examples/amr_clothing.py b/examples/amr_clothing.py
new file mode 100644
index 00000000..43150961
--- /dev/null
+++ b/examples/amr_clothing.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Example for Adversarial Training Towards Robust Multimedia Recommender System
+"""
+
+import cornac
+from cornac.datasets import amazon_clothing
+from cornac.data import ImageModality
+from cornac.eval_methods import RatioSplit
+
+
+# AMR uses adversarial training on item visual features to build a more robust recommender
+# The necessary data can be loaded as follows
+feedback = amazon_clothing.load_feedback()
+features, item_ids = amazon_clothing.load_visual_feature()  # BIG file
+
+# Instantiate an ImageModality, it makes it convenient to work with visual auxiliary information
+# For more details, please refer to the tutorial on how to work with auxiliary data
+item_image_modality = ImageModality(features=features, ids=item_ids, normalized=True)
+
+# Define an evaluation method to split feedback into train and test sets
+ratio_split = RatioSplit(
+    data=feedback,
+    test_size=0.1,
+    rating_threshold=0.5,
+    exclude_unknowns=True,
+    verbose=True,
+    item_image=item_image_modality,
+)
+
+# Instantiate AMR (the adversarial loss weight is the `lambda_adv` keyword of AMR.__init__)
+amr = cornac.models.AMR(
+    k=32,
+    k2=32,
+    n_epochs=1,
+    batch_size=100,
+    learning_rate=0.001,
+    lambda_w=1,
+    lambda_b=0.01,
+    lambda_e=0.0,
+    lambda_adv=1.0,
+    use_gpu=True,
+)
+
+# Instantiate evaluation measures
+rec_50 = cornac.metrics.Recall(k=50)
+
+# Put everything together into an experiment and run it
+cornac.Experiment(eval_method=ratio_split, models=[amr], metrics=[rec_50]).run()
diff --git a/examples/causalrec_clothing.py b/examples/causalrec_clothing.py
index b00cc943..011036ed 100644
--- a/examples/causalrec_clothing.py
+++ b/examples/causalrec_clothing.py
@@ -52,9 +52,10 @@ causalrec = cornac.models.CausalRec(
     lambda_w=1,
     lambda_b=0.01,
     lambda_e=0.0,
-    use_gpu=True,
     mean_feat=features.mean(axis=0),
-    tanh=1
+    tanh=1,
+    lambda_2=0.8,
+    use_gpu=True,
 )
 
 # Instantiate evaluation measures
-- 
GitLab