From 296d2d90f654c27ffa914f933e0bd82bc48d38c5 Mon Sep 17 00:00:00 2001
From: Max Beckers <77118862+mabeckers@users.noreply.github.com>
Date: Wed, 20 Mar 2024 09:58:12 -0700
Subject: [PATCH] Add DMRL model (#597)

---
 README.md                                     |   1 +
 cornac/models/__init__.py                     |   1 +
 cornac/models/dmrl/__init__.py                |   1 +
 cornac/models/dmrl/d_cor_calc.py              | 116 ++++
 cornac/models/dmrl/dmrl.py                    | 321 ++++++++++
 cornac/models/dmrl/pwlearning_sampler.py      | 116 ++++
 cornac/models/dmrl/recom_dmrl.py              | 550 ++++++++++++++++++
 cornac/models/dmrl/requirements.txt           |   7 +
 cornac/models/dmrl/transformer_text.py        |  98 ++++
 cornac/models/dmrl/transformer_vision.py      | 149 +++++
 docs/source/api_ref/models.rst                |   5 +
 examples/README.md                            |   4 +
 examples/dmrl_clothes_example.py              |  59 ++
 examples/dmrl_example.py                      |  50 ++
 pytest.ini                                    |   3 +-
 tests/cornac/models/__init__.py               |   0
 tests/cornac/models/dmrl/__init__.py          |   0
 .../cornac/models/dmrl/test_distance_calc.py  |  47 ++
 .../models/dmrl/test_pwlearning_sampler.py    | 100 ++++
 .../models/dmrl/test_transformertext.py       |  53 ++
 .../models/dmrl/test_transformervision.py     | 112 ++++
 21 files changed, 1792 insertions(+), 1 deletion(-)
 create mode 100644 cornac/models/dmrl/__init__.py
 create mode 100644 cornac/models/dmrl/d_cor_calc.py
 create mode 100644 cornac/models/dmrl/dmrl.py
 create mode 100644 cornac/models/dmrl/pwlearning_sampler.py
 create mode 100644 cornac/models/dmrl/recom_dmrl.py
 create mode 100644 cornac/models/dmrl/requirements.txt
 create mode 100644 cornac/models/dmrl/transformer_text.py
 create mode 100644 cornac/models/dmrl/transformer_vision.py
 create mode 100644 examples/dmrl_clothes_example.py
 create mode 100644 examples/dmrl_example.py
 create mode 100644 tests/cornac/models/__init__.py
 create mode 100644 tests/cornac/models/dmrl/__init__.py
 create mode 100644 tests/cornac/models/dmrl/test_distance_calc.py
 create mode 100644 tests/cornac/models/dmrl/test_pwlearning_sampler.py
 create mode 100644 tests/cornac/models/dmrl/test_transformertext.py
 create mode 100644 tests/cornac/models/dmrl/test_transformervision.py

diff --git a/README.md b/README.md
index 93406bd9..b25e8d8a 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
 | Year | Model and paper | Model type | Require-ments | Examples |
 | :---: | --- | :---: | :---: | :---: |
 | 2024 | [Hypergraphs with Attention on Reviews (HypAR)](cornac/models/hypar), [paper](https://doi.org/10.1007/978-3-031-56027-9_14)| Hybrid / Sentiment / Explainable | [reqs](cornac/models/hypar/requirements_cu116.txt) | [exp](https://github.com/PreferredAI/HypAR)
+| 2022 | [Disentangled Multimodal Representation Learning for Recommendation (DMRL)](cornac/models/dmrl), [paper](https://arxiv.org/pdf/2203.05406.pdf) | Content-Based / Text & Image | [reqs](cornac/models/dmrl/requirements.txt) | [exp](examples/dmrl_example.py)
 | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | Collaborative Filtering / Content-Based | [reqs](cornac/models/bivaecf/requirements.txt) | [exp](https://github.com/PreferredAI/bi-vae)
 |      | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | Content-Based / Image | [reqs](cornac/models/causalrec/requirements.txt) | [exp](examples/causalrec_clothing.py)
 |      | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | Explainable | N/A | [exp](https://github.com/PreferredAI/ComparER)
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index f4256b81..6505ee25 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -38,6 +38,7 @@ from .conv_mf import ConvMF
 from .ctr import CTR
 from .cvae import CVAE
 from .cvaecf import CVAECF
+from .dmrl import DMRL
 from .dnntsp import DNNTSP
 from .ease import EASE
 from .efm import EFM
diff --git a/cornac/models/dmrl/__init__.py b/cornac/models/dmrl/__init__.py
new file mode 100644
index 00000000..e3c7a1c1
--- /dev/null
+++ b/cornac/models/dmrl/__init__.py
@@ -0,0 +1 @@
+from .recom_dmrl import DMRL
\ No newline at end of file
diff --git a/cornac/models/dmrl/d_cor_calc.py b/cornac/models/dmrl/d_cor_calc.py
new file mode 100644
index 00000000..9094a9cf
--- /dev/null
+++ b/cornac/models/dmrl/d_cor_calc.py
@@ -0,0 +1,116 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import torch
+
+
+class DistanceCorrelationCalculator:
+    """
+    Calculates the disentangled loss for the DMRL model.
+    Please see https://arxiv.org/pdf/2203.05406.pdf for more details.
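+
+    Example (illustrative; inputs have shape (batch_size, 1 + num_neg, factor_dim))::
+
+        calc = DistanceCorrelationCalculator(n_factors=4, num_neg=4)
+        X = torch.rand(32, 5, 25)
+        Y = torch.rand(32, 5, 25)
+        cor = calc.calculate_cor(X, Y)  # 1D tensor of length 1 + num_neg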
+    """
+
+    def __init__(self, n_factors, num_neg) -> None:
+        self.n_factors = n_factors
+        self.num_neg = num_neg
+
+    def calculate_cov(self, X, Y):
+        """
+        Computes the distance covariance between X and Y.
+        :param X: A 3D torch tensor.
+        :param Y: A 3D torch tensor.
+        :return: A 1D torch tensor of len 1+num_neg.
+        """
+        # first create centered distance matrices
+        X = self.cent_dist(X)
+        Y = self.cent_dist(Y)
+
+        # batch_size is dim 1, as dim 0 is one positive and num_neg negative samples
+        n_samples = X.shape[1]
+        # then calculate the covariance as a 1D array of length 1+num_neg
+        cov = torch.sqrt(torch.max(torch.sum(X * Y, dim=(1, 2)) / (n_samples * n_samples), torch.tensor(1e-5)))
+        return cov
+
+    def calculate_var(self, X):
+        """
+        Computes the distance variance of X.
+        :param X: A 3D torch tensor.
+        :return: A 1D torch tensor of len 1+num_neg.
+        """
+        return self.calculate_cov(X, X)
+
+    def calculate_cor(self, X, Y):
+        """
+        Computes the distance correlation between X and Y.
+
+        :param X: A 3D torch tensor.
+        :param Y: A 3D torch tensor.
+        :return: A 1D torch tensor of len 1+num_neg.
+        """
+        return self.calculate_cov(X, Y) / torch.sqrt(torch.max(self.calculate_var(X) * self.calculate_var(Y), torch.tensor(0.0)))
+
+    def cent_dist(self, X):
+        """
+        Computes the pairwise Euclidean distance between rows of X and centers
+        each cell of the distance matrix with row mean, column mean, and grand mean.
+        """
+        # put the samples from dim 1 into dim 0
+        X = torch.transpose(X, dim0=0, dim1=1)
+
+        # now use ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2 to build the distance matrix
+        first_part = torch.sum(torch.square(X), dim=-1, keepdims=True)
+        middle_part = torch.matmul(X, torch.transpose(X, dim0=1, dim1=2))
+        last_part = torch.transpose(first_part, dim0=1, dim1=2)
+
+        D = torch.sqrt(torch.max(first_part - 2 * middle_part + last_part, torch.tensor(1e-5)))
+        # dim0 is the negative samples, dim1 is batch_size, dim2 is the kth factor of the embedding_dim
+
+        row_mean = torch.mean(D, dim=2, keepdim=True)
+        column_mean = torch.mean(D, dim=1, keepdim=True)
+        global_mean = torch.mean(D, dim=(1, 2), keepdim=True)
+        D = D - row_mean - column_mean + global_mean
+        return D
+
+    def calculate_disentangled_loss(
+            self,
+            item_embedding_factors: torch.Tensor,
+            user_embedding_factors: torch.Tensor,
+            text_embedding_factors: torch.Tensor,
+            image_embedding_factors: torch.Tensor):
+        """
+        Calculates the disentangled loss for the given factors.
+
+        :param item_embedding_factors: A list of 3D torch tensors.
+        :param user_embedding_factors: A list of 3D torch tensors.
+        :param text_embedding_factors: A list of 3D torch tensors.
+        :param image_embedding_factors: A list of 3D torch tensors.
+        :return: A 1D torch tensor of len 1+num_neg.
+        """
+        cor_loss = torch.tensor([0.0] * (1 + self.num_neg))
+        for i in range(0, self.n_factors - 2):
+            for j in range(i + 1, self.n_factors - 1):
+                cor_loss += self.calculate_cor(item_embedding_factors[i], item_embedding_factors[j])
+                cor_loss += self.calculate_cor(user_embedding_factors[i], user_embedding_factors[j])
+                if text_embedding_factors[i].numel() > 0:
+                    cor_loss += self.calculate_cor(text_embedding_factors[i], text_embedding_factors[j])
+                if image_embedding_factors[i].numel() > 0:
+                    cor_loss += self.calculate_cor(image_embedding_factors[i], image_embedding_factors[j])
+
+        cor_loss = cor_loss / ((self.n_factors + 1.0) * self.n_factors / 2)
+
+        # two options: we can either return the sum over the 1 positive and
+        # num_neg negative samples, or only the loss of the one positive
+        # sample, as done in the paper
+
+        # return torch.sum(cor_loss)
+        return cor_loss[0]
diff --git a/cornac/models/dmrl/dmrl.py b/cornac/models/dmrl/dmrl.py
new file mode 100644
index 00000000..19841e0d
--- /dev/null
+++ b/cornac/models/dmrl/dmrl.py
@@ -0,0 +1,321 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import List, Tuple
+import torch
+import torch.nn as nn
+from cornac.models.dmrl.d_cor_calc import DistanceCorrelationCalculator
+from dataclasses import dataclass
+
+from cornac.utils.common import get_rng
+from cornac.utils.init_utils import xavier_normal
+
+
+@dataclass
+class EmbeddingFactorLists:
+    """
+    A dataclass for holding the embedding factors for each modality.
+    """
+
+    user_embedding_factors: List[torch.Tensor]
+    item_embedding_factors: List[torch.Tensor]
+    text_embedding_factors: List[torch.Tensor] = None
+    image_embedding_factors: List[torch.Tensor] = None
+
+
+class DMRLModel(nn.Module):
+    """
+    The actual Disentangled Multi-Modal Recommendation Model neural network.
+    """
+
+    def __init__(
+        self,
+        num_users: int,
+        num_items: int,
+        embedding_dim: int,
+        text_dim: int,
+        image_dim: int,
+        dropout: float,
+        num_neg: int,
+        num_factors: int,
+        seed: int = 123,
+    ):
+
+        super(DMRLModel, self).__init__()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.num_factors = num_factors
+        self.num_neg = num_neg
+        self.embedding_dim = embedding_dim
+        self.num_modalities = 1 + bool(text_dim) + bool(image_dim)
+        self.dropout = dropout
+        self.grad_norms = []
+        self.param_norms = []
+        self.ui_ratings = []
+        self.ut_ratings = []
+        self.ui_attention = []
+        self.ut_attention = []
+
+        rng = get_rng(seed)
+
+        if text_dim:
+            self.text_module = torch.nn.Sequential(
+                torch.nn.Dropout(p=self.dropout),
+                torch.nn.Linear(text_dim, 150),
+                torch.nn.LeakyReLU(),
+                torch.nn.Dropout(p=self.dropout),
+                torch.nn.Linear(150, embedding_dim),
+                torch.nn.LeakyReLU(),
+            )
+            self.text_module[1].weight.data = torch.from_numpy(
+                xavier_normal([150, text_dim], random_state=rng)
+            )
+            # initialize the second linear layer the same way
+            self.text_module[4].weight.data = torch.from_numpy(
+                xavier_normal([embedding_dim, 150], random_state=rng)
+            )
+
+        if image_dim:
+            self.image_module = torch.nn.Sequential(
+                torch.nn.Dropout(p=self.dropout),
+                torch.nn.Linear(image_dim, 150),
+                torch.nn.LeakyReLU(),
+                torch.nn.Dropout(p=self.dropout),
+                torch.nn.Linear(150, embedding_dim),
+                torch.nn.LeakyReLU(),
+            )
+
+        self.user_embedding = torch.nn.Embedding(num_users, embedding_dim)
+        self.item_embedding = torch.nn.Embedding(num_items, embedding_dim)
+
+        self.user_embedding.weight.data = torch.from_numpy(
+            xavier_normal([num_users, embedding_dim], random_state=rng)
+        )
+        self.item_embedding.weight.data = torch.from_numpy(
+            xavier_normal([num_items, embedding_dim], random_state=rng)
+        )
+
+        self.factor_size = self.embedding_dim // self.num_factors
+
+        self.attention_layer = torch.nn.Sequential(
+            torch.nn.Dropout(p=self.dropout),
+            torch.nn.Linear(
+                (self.num_modalities + 1) * self.factor_size, self.num_modalities
+            ),
+            torch.nn.Tanh(),
+            torch.nn.Dropout(p=self.dropout),
+            torch.nn.Linear(self.num_modalities, self.num_modalities, bias=False),
+            torch.nn.Softmax(dim=-1),
+        )
+        self.attention_layer[1].weight.data = torch.from_numpy(
+            xavier_normal(
+                [self.num_modalities, (self.num_modalities + 1) * self.factor_size],
+                random_state=rng,
+            )
+        )
+        self.attention_layer[4].weight.data = torch.from_numpy(
+            xavier_normal([self.num_modalities, self.num_modalities], random_state=rng)
+        )
+
+        self.grad_dict = {i[0]: [] for i in self.named_parameters()}
+
+    def forward(
+        self, batch: torch.Tensor, text: torch.Tensor, image: torch.Tensor
+    ) -> Tuple[EmbeddingFactorLists, torch.Tensor]:
+        """
+        Forward pass of the model.
+
+        Parameters
+        ----------
+        batch: torch.Tensor
+            A batch of data. The first column contains the user indices, the
+            rest of the columns contain the item indices (one pos and num_neg negatives)
+        text: torch.Tensor
+            The text data for the items in the batch (encoded)
+        image: torch.Tensor
+            The image data for the items in the batch (encoded)
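+
+        Returns
+        -------
+        A tuple of the EmbeddingFactorLists and the rating scores, a tensor
+        of shape (batch_size, 1 + num_neg).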
+        """
+        text_embedding_factors = [
+            torch.tensor([]).to(self.device) for _ in range(self.num_factors)
+        ]
+        image_embedding_factors = [
+            torch.tensor([]).to(self.device) for _ in range(self.num_factors)
+        ]
+        users = batch[:, 0]
+        items = batch[:, 1:]
+
+        # handle text
+        if text is not None:
+            text_embedding = self.text_module(
+                torch.nn.functional.normalize(text, dim=-1)
+            )
+            text_embedding_factors = torch.split(
+                text_embedding, self.embedding_dim // self.num_factors, dim=-1
+            )
+
+        # handle image
+        if image is not None:
+            image_embedding = self.image_module(
+                torch.nn.functional.normalize(image, dim=-1)
+            )
+            image_embedding_factors = torch.split(
+                image_embedding, self.embedding_dim // self.num_factors, dim=-1
+            )
+
+        # handle users
+        user_embedding = self.user_embedding(users)
+        # we have to get users into shape batch, 1+num_neg, embedding_dim
+        # therefore we repeat the users across the 1 pos and num_neg items
+        user_embedding_inflated = user_embedding.unsqueeze(1).repeat(
+            1, items.shape[1], 1
+        )
+        user_embedding_factors = torch.split(
+            user_embedding_inflated, self.embedding_dim // self.num_factors, dim=-1
+        )
+
+        # handle items
+        item_embedding = self.item_embedding(items)
+        item_embedding_factors = torch.split(
+            item_embedding, self.embedding_dim // self.num_factors, dim=-1
+        )
+
+        embedding_factor_lists = EmbeddingFactorLists(
+            user_embedding_factors,
+            item_embedding_factors,
+            text_embedding_factors,
+            image_embedding_factors,
+        )
+
+        # attentionLayer: implemented per factor k
+        batch_size = users.shape[0]
+        ratings_sum_over_mods = torch.zeros((batch_size, 1 + self.num_neg)).to(
+            self.device
+        )
+        for i in range(self.num_factors):
+
+            concatted_features = torch.concatenate(
+                [
+                    user_embedding_factors[i],
+                    item_embedding_factors[i],
+                    text_embedding_factors[i],
+                    image_embedding_factors[i],
+                ],
+                axis=2,
+            )
+            attention = self.attention_layer(
+                torch.nn.functional.normalize(concatted_features, dim=-1)
+            )
+
+            r_ui = attention[:, :, 0] * torch.nn.Softplus()(
+                torch.sum(
+                    user_embedding_factors[i] * item_embedding_factors[i], axis=-1
+                )
+            )
+            # log rating
+            self.ui_ratings.append(torch.norm(r_ui.detach().flatten()).cpu())
+
+            factor_rating = r_ui
+
+            if text is not None:
+                r_ut = attention[:, :, 1] * torch.nn.Softplus()(
+                    torch.sum(
+                        user_embedding_factors[i] * text_embedding_factors[i], axis=-1
+                    )
+                )
+                factor_rating = factor_rating + r_ut
+                # log rating
+                self.ut_ratings.append(torch.norm(r_ut.detach().flatten()).cpu())
+
+            if image is not None:
+                # the image modality is always the last column of the attention
+                # output (index 1 without text, index 2 with text)
+                r_uv = attention[:, :, self.num_modalities - 1] * torch.nn.Softplus()(
+                    torch.sum(
+                        user_embedding_factors[i] * image_embedding_factors[i], axis=-1
+                    )
+                )
+                factor_rating = factor_rating + r_uv
+                self.ui_ratings.append(torch.norm(r_uv.detach().flatten()).cpu())
+
+            # sum up over modalities and running sum over factors
+            ratings_sum_over_mods = ratings_sum_over_mods + factor_rating
+
+        return embedding_factor_lists, ratings_sum_over_mods
+
+    def log_gradients_and_weights(self):
+        """
+        Stores most recent gradient norms in a list.
+        """
+
+        for i in self.named_parameters():
+            self.grad_dict[i[0]].append(torch.norm(i[1].grad.detach().flatten()).item())
+
+        total_norm_grad = torch.norm(
+            torch.cat([p.grad.detach().flatten() for p in self.parameters()])
+        )
+        self.grad_norms.append(total_norm_grad.item())
+
+        total_norm_param = torch.norm(
+            torch.cat([p.detach().flatten() for p in self.parameters()])
+        )
+        self.param_norms.append(total_norm_param.item())
+
+    def reset_grad_metrics(self):
+        """
+        Reset the gradient metrics.
+        """
+        self.grad_norms = []
+        self.param_norms = []
+        self.grad_dict = {i[0]: [] for i in self.named_parameters()}
+        self.ui_ratings = []
+        self.ut_ratings = []
+        self.ui_attention = []
+        self.ut_attention = []
+
+
+class DMRLLoss(nn.Module):
+    """
+    The disentangled multi-modal recommendation model loss function. It is a
+    combination of a pairwise ranking loss and the disentangled (distance
+    correlation) loss. For details see the DMRL paper.
+    """
+
+    def __init__(self, decay_c, num_factors, num_neg):
+        super(DMRLLoss, self).__init__()
+        self.decay_c = decay_c
+        self.distance_cor_calc = DistanceCorrelationCalculator(
+            n_factors=num_factors, num_neg=num_neg
+        )
+
+    def forward(
+        self, embedding_factor_lists: EmbeddingFactorLists, rating_scores: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Calculates the loss for the batch of data.
+        """
+        r_pos = rating_scores[:, 0]
+        # from the num_neg many negative sampled items, we want to find the one
+        # with the largest score to have one negative sample per user in our
+        # batch
+        r_neg = torch.max(rating_scores[:, 1:], dim=1).values
+
+        # define the ranking loss for pairwise-based ranking approach
+        loss_BPR = torch.sum(torch.nn.Softplus()(-(r_pos - r_neg)))
+
+        # regularizer loss is added as weight decay in optimization function
+        if self.decay_c > 0:
+            disentangled_loss = self.distance_cor_calc.calculate_disentangled_loss(
+                embedding_factor_lists.user_embedding_factors,
+                embedding_factor_lists.item_embedding_factors,
+                embedding_factor_lists.text_embedding_factors,
+                embedding_factor_lists.image_embedding_factors,
+            )
+
+            return loss_BPR + self.decay_c * disentangled_loss
+        return loss_BPR
diff --git a/cornac/models/dmrl/pwlearning_sampler.py b/cornac/models/dmrl/pwlearning_sampler.py
new file mode 100644
index 00000000..7bbcab96
--- /dev/null
+++ b/cornac/models/dmrl/pwlearning_sampler.py
@@ -0,0 +1,116 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import List
+import torch.utils.data as data
+from cornac.data.dataset import Dataset
+import numpy as np
+
+
+class PWLearningSampler(data.Dataset):
+    """
+    Sampler for the pairwise-based ranking loss function. This sampler returns
+    a batch of positive and negative items: per user index there will be one
+    positive and num_neg negative items.
+    It is meant to be used in a PyTorch DataLoader, through which loading
+    can be distributed amongst multiple workers.
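+
+    Example (illustrative; assumes `train_set` is a built cornac Dataset)::
+
+        sampler = PWLearningSampler(train_set, num_neg=4)
+        loader = data.DataLoader(sampler, batch_size=32, shuffle=True)
+        batch = next(iter(loader))  # columns: user, pos item, num_neg neg items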
+    """
+    def __init__(self, cornac_dataset: Dataset, num_neg: int):
+        self.data = cornac_dataset
+        self.num_neg = num_neg
+        # make sure we only have positive ratings, not unseen interactions
+        assert np.all(self.data.uir_tuple[2] > 0)
+        self.user_array = self.data.uir_tuple[0]
+
+        self.item_array = self.data.uir_tuple[1]
+        self.unique_items = np.unique(self.item_array)
+        self.unique_users = np.unique(self.user_array)
+        self.user_item_array = np.vstack([self.user_array, self.item_array]).T
+        # make sure user ids are ascending from 0
+        assert np.all(np.unique(self.user_array) == np.arange(self.unique_users.shape[0]))
+
+    def __getitems__(self, list_of_indexs: List[int]):
+        """
+        Vectorized version of __getitem__.
+
+        Uses list_of_indexs to index into the uir_tuple from the cornac dataset
+        and thus retrieve one positive item per given user. Additionally it
+        randomly samples num_neg negative items per user.
+
+        Parameters
+        ----------
+        list_of_indexs: List[int]
+            List of indices to sample from the uir_tuple given in the cornac
+            dataset.
+        """
+        batch_size = len(list_of_indexs)
+        users = self.user_array[list_of_indexs]
+        pos_items = self.item_array[list_of_indexs]
+
+        pos_u_i = np.vstack([users, pos_items]).T
+
+        # sample negative items per user
+        neg_item_list = []
+        for _ in range(self.num_neg):
+            neg_items = np.random.choice(self.data.csr_matrix.shape[1], batch_size)
+            # make sure we don't sample a positive item
+            candidates = self.data.csr_matrix[users, neg_items]
+            while candidates.nonzero()[0].size != 0:
+                replacement_neg_items = np.random.choice(self.data.csr_matrix.shape[1], candidates.nonzero()[0].size)
+                neg_items[candidates.nonzero()[1]] = replacement_neg_items
+                candidates = self.data.csr_matrix[users, neg_items]
+            
+            neg_item_list.append(neg_items)
+        neg_items = np.vstack(neg_item_list).T
+        return np.hstack([pos_u_i, neg_items])
+
+    def __getitem__(self, index):
+        """
+        Uses index into the uir_tuple from the cornac dataset and thus
+        retrieves one positive user-item pair. Additionally it randomly
+        samples num_neg negative items for that user.
+
+        Parameters
+        ----------
+        index: int
+            Index to sample from the uir_tuple given in the cornac dataset.
+        """
+        # first select index tuple
+        user = self.user_array[index]
+        item = self.item_array[index]
+
+        # now construct positive case
+        pos_u_i = [user, item]
+
+        i = 0
+        neg_i = []
+        while i < self.num_neg:
+            neg_example = np.random.choice(self.data.uir_tuple[1])
+
+            idxs_of_item = np.where(self.item_array == neg_example)
+            users_who_have_rated_item = self.user_array[idxs_of_item]
+
+            if user not in users_who_have_rated_item:
+                i += 1
+                neg_i = neg_i + [neg_example]
+        # return user, item_positive, num_neg * item_neg array
+        return np.array(pos_u_i + neg_i)
+
+    def __len__(self):
+        """
+        Return length of sampler as length of uir_tuple from cornac dataset.
+        """
+        return len(self.data.uir_tuple[0])
diff --git a/cornac/models/dmrl/recom_dmrl.py b/cornac/models/dmrl/recom_dmrl.py
new file mode 100644
index 00000000..e87f5906
--- /dev/null
+++ b/cornac/models/dmrl/recom_dmrl.py
@@ -0,0 +1,550 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import List, Tuple
+
+import numpy as np
+
+from cornac.data.dataset import Dataset
+from cornac.data import FeatureModality, TextModality, ImageModality
+from cornac.metrics.ranking import Precision, Recall
+from cornac.models.recommender import Recommender
+
+
+class DMRL(Recommender):
+    """
+    Disentangled Multimodal Representation Learning (DMRL).
+
+    Parameters
+    ----------
+    name: string, default: 'DMRL'
+        The name of the recommender model.
+
+    batch_size: int, optional, default: 32
+        The number of samples per batch to load.
+
+    learning_rate: float, optional, default: 1e-4
+        The learning rate for the optimizer.
+
+    decay_c: float, optional, default: 1
+        The decay for the disentangled loss term in the loss function.
+
+    decay_r: float, optional, default: 0.01
+        The decay for the regularization term in the loss function.
+
+    epochs: int, optional, default: 10
+        The number of epochs to train the model.
+
+    embedding_dim: int, optional, default: 100
+        The dimension of the embeddings.
+
+    bert_text_dim: int, optional, default: 384
+        The dimension of the BERT text embeddings coming from the Hugging Face transformer model.
+
+    image_dim: int, optional, default: None
+        The dimension of the image embeddings.
+
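+    dropout: float, optional, default: 0
+        The dropout rate applied in the text/image encoders and the attention layer.
+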
+    num_neg: int, optional, default: 4
+        The number of negative samples to use in the training per user per batch (1 positive and num_neg negatives are used)
+
+    num_factors: int, optional, default: 4
+        The number of factors to use in the model.
+
+    trainable: bool, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already trained.
+
+    verbose: bool, optional, default: False
+        When True, the model prints out more information during training.
+
+    log_metrics: bool, optional, default: False
+        When True, the model logs metrics to tensorboard.
+
+    References
+    ----------
+    * Fan Liu, Huilin Chen, Zhiyong Cheng, Anan Liu, Liqiang Nie, Mohan Kankanhalli.
+        DMRL: Disentangled Multimodal Representation Learning for Recommendation.
+        https://arxiv.org/pdf/2203.05406.pdf
+    """
+
+    def __init__(
+        self,
+        name: str = "DMRL",
+        batch_size: int = 32,
+        learning_rate: float = 1e-4,
+        decay_c: float = 1,
+        decay_r: float = 0.01,
+        epochs: int = 10,
+        embedding_dim: int = 100,
+        bert_text_dim: int = 384,
+        image_dim: int = None,
+        dropout: float = 0,
+        num_neg: int = 4,
+        num_factors: int = 4,
+        trainable: bool = True,
+        verbose: bool = False,
+        log_metrics: bool = False,
+    ):
+
+        super().__init__(name=name, trainable=trainable, verbose=verbose)
+
+        self.learning_rate = learning_rate
+        self.decay_c = decay_c
+        self.decay_r = decay_r
+        self.batch_size = batch_size
+        self.epochs = epochs
+        self.verbose = verbose
+        self.embedding_dim = embedding_dim
+        self.text_dim = bert_text_dim
+        self.image_dim = image_dim
+        self.dropout = dropout
+        self.num_neg = num_neg
+        self.num_factors = num_factors
+        self.log_metrics = log_metrics
+        if log_metrics:
+            from torch.utils.tensorboard import SummaryWriter
+
+            self.tb_writer = SummaryWriter("temp/tb_data/run_1")
+
+        if self.num_factors == 1:
+            # deactivate the disentangled portion of the loss if there's only 1 factor
+            self.decay_c = 0
+
+    def fit(self, train_set: Dataset, val_set=None):
+        """Fit the model to observations.
+
+        Parameters
+        ----------
+        train_set: :obj:`cornac.data.Dataset`, required
+            User-Item preference data as well as additional modalities.
+
+        val_set: :obj:`cornac.data.Dataset`, optional, default: None
+            User-Item preference data for model selection purposes (e.g., early stopping).
+        """
+        Recommender.fit(self, train_set, val_set)
+
+        if self.trainable:
+            self._fit_dmrl(train_set, val_set)
+
+        return self
+
+    def get_item_image_embedding(self, batch):
+        """
+        Get the item image embeddings from the image modality. Expects the
+        image modality to be pre-encoded and available as a numpy array.
+
+        Parameters
+        ----------
+        batch: torch.Tensor
+            User indices in the first column, positive item indices in the
+            second; all other columns are negative item indices.
+        """
+        import torch
+
+        if not hasattr(self, "item_image"):
+            return None
+
+        shape = batch[:, 1:].shape
+        all_items = batch[:, 1:].flatten()
+
+        item_image_embedding = self.item_image.features[all_items, :].reshape(
+            (*shape, self.item_image.feature_dim)
+        )
+
+        if not isinstance(item_image_embedding, torch.Tensor):
+            item_image_embedding = torch.tensor(
+                item_image_embedding, dtype=torch.float32
+            )
+
+        return item_image_embedding
+
+    def get_item_text_embeddings(self, batch):
+        """
+        Get the item text embeddings from the BERT model, either by encoding
+        the text on the fly or by using the pre-encoded text.
+
+        Parameters
+        ----------
+        batch: torch.Tensor
+            User indices in the first column, positive item indices in the
+            second; all other columns are negative item indices.
+        """
+        import torch
+
+        shape = batch[:, 1:].shape
+        all_items = batch[:, 1:].flatten()
+
+        if not hasattr(self, "item_text"):
+            return None
+
+        if not self.item_text.preencoded:
+            item_text_embeddings = self.item_text.batch_encode(all_items)
+            item_text_embeddings = item_text_embeddings.reshape(
+                (*shape, self.item_text.output_dim)
+            )
+        else:
+            item_text_embeddings = self.item_text.features[all_items]
+            item_text_embeddings = item_text_embeddings.reshape(
+                (*shape, self.item_text.output_dim)
+            )
+
+        if not isinstance(item_text_embeddings, torch.Tensor):
+            item_text_embeddings = torch.tensor(
+                item_text_embeddings, dtype=torch.float32
+            )
+
+        return item_text_embeddings
+
+    def get_modality_embeddings(self, batch):
+        """
+        Get the modality embeddings for both text and image from the
+        respective modality instances.
+
+        Parameters
+        ----------
+        batch: torch.Tensor
+            User indices in the first column, positive item indices in the
+            second; all other columns are negative item indices.
+        """
+        item_text_embeddings = self.get_item_text_embeddings(batch)
+        item_image_embeddings = self.get_item_image_embedding(batch)
+
+        return item_text_embeddings, item_image_embeddings
+
+    def _fit_dmrl(self, train_set: Dataset, val_set: Dataset = None):
+        """
+        Fit the model to observations.
+
+        Parameters
+        ----------
+        train_set: User-Item preference data as well as additional modalities.
+        """
+        import torch
+        from torch.utils.data import DataLoader
+
+        from cornac.models.dmrl.dmrl import DMRLLoss, DMRLModel
+        from cornac.models.dmrl.pwlearning_sampler import PWLearningSampler
+
+        self.initialize_and_build_modalities(train_set)
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device {self.device} for training")
+
+        self.sampler = PWLearningSampler(train_set, num_neg=self.num_neg)
+
+        self.model = DMRLModel(
+            self.num_users,
+            self.num_items,
+            self.embedding_dim,
+            self.text_dim,
+            self.image_dim,
+            self.dropout,
+            self.num_neg,
+            self.num_factors,
+        ).to(self.device)
+
+        # use the configured decay_c rather than a hard-coded value
+        loss_function = DMRLLoss(
+            decay_c=self.decay_c, num_factors=self.num_factors, num_neg=self.num_neg
+        )
+
+        # add hyperparams to tensorboard
+        if self.log_metrics:
+            self.tb_writer.add_hparams(
+                {
+                    "learning_rate": self.learning_rate,
+                    "decay_c": self.decay_c,
+                    "decay_r": self.decay_r,
+                    "batch_size": self.batch_size,
+                    "epochs": self.epochs,
+                    "embedding_dim": self.embedding_dim,
+                    "bert_text_dim": self.text_dim,
+                    "num_neg": self.num_neg,
+                    "num_factors": self.num_factors,
+                    "dropout": self.dropout,
+                },
+                {},
+            )
+
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            weight_decay=self.decay_r,
+            betas=(0.9, 0.999),
+        )
+        # optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.learning_rate, weight_decay=self.decay_r)
+
+        # Create learning rate scheduler if needed
+        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)
+        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.25, step_size=35)
+
+        dataloader = DataLoader(
+            self.sampler,
+            batch_size=self.batch_size,
+            num_workers=0,
+            shuffle=True,
+            prefetch_factor=None,
+        )
+
+        if val_set is not None:
+            self.val_sampler = PWLearningSampler(val_set, num_neg=self.num_neg)
+            val_dataloader = DataLoader(
+                self.val_sampler,
+                batch_size=self.batch_size,
+                num_workers=0,
+                shuffle=True,
+                prefetch_factor=None,
+            )
+
+        j = 1
+        stop = False
+        # Training loop
+        for epoch in range(self.epochs):
+            if stop:
+                break
+            running_loss = 0
+            running_loss_val = 0
+            last_loss = 0
+            i = 0
+
+            batch: torch.Tensor
+            for i, batch in enumerate(dataloader):
+
+                optimizer.zero_grad()
+                item_text_embeddings, item_image_embeddings = (
+                    self.get_modality_embeddings(batch)
+                )
+
+                # move the data to the device
+                batch = batch.to(self.device)
+                if item_text_embeddings is not None:
+                    item_text_embeddings = item_text_embeddings.to(self.device)
+                if item_image_embeddings is not None:
+                    item_image_embeddings = item_image_embeddings.to(self.device)
+
+                # Forward pass
+                embedding_factor_lists, rating_scores = self.model(
+                    batch, item_text_embeddings, item_image_embeddings
+                )
+                # preds = self.model(u_batch, i_batch, text)
+                loss = loss_function(embedding_factor_lists, rating_scores)
+
+                # Backward pass and optimize
+                loss.backward()
+                # torch.nn.utils.clip_grad_value_(self.model.parameters(), 5) # use if exploding gradient becomes an issue
+                if self.log_metrics:
+                    self.model.log_gradients_and_weights()
+
+                optimizer.step()
+
+                if val_set is not None:
+                    val_batch = next(iter(val_dataloader))
+                    item_text_embeddings_val, item_image_embeddings_val = (
+                        self.get_modality_embeddings(val_batch)
+                    )
+
+                    # Forward pass
+                    with torch.no_grad():
+                        embedding_factor_lists_val, rating_scores_val = self.model(
+                            val_batch,
+                            item_text_embeddings_val,
+                            item_image_embeddings_val,
+                        )
+                        # preds = self.model(u_batch, i_batch, text)
+                        loss_val = loss_function(
+                            embedding_factor_lists_val, rating_scores_val
+                        )
+                        running_loss_val += loss_val.item()
+
+                # Gather data and report
+                running_loss += loss.item()
+                divider = 5
+                if i % divider == 4:
+                    last_loss = running_loss / divider  # loss per batch
+                    print("  batch {} loss: {}".format(i + 1, last_loss))
+
+                    if self.log_metrics:
+                        # tb_x = epoch * len(dataloader) + i + 1
+                        self.tb_writer.add_scalar("Loss/train", last_loss, j)
+                        self.tb_writer.add_scalar(
+                            "Loss/val", running_loss_val / devider, j
+                        )
+                        self.tb_writer.add_scalar(
+                            "Gradient Norm/train", np.mean(self.model.grad_norms), j
+                        )
+                        self.tb_writer.add_scalar(
+                            "Param Norm/train", np.mean(self.model.param_norms), j
+                        )
+                        self.tb_writer.add_scalar(
+                            "User-Item based rating", np.mean(self.model.ui_ratings), j
+                        )
+                        self.tb_writer.add_scalar(
+                            "User-Text based rating", np.mean(self.model.ut_ratings), j
+                        )
+                        self.tb_writer.add_scalar(
+                            "User-Itm Attention", np.mean(self.model.ui_attention), j
+                        )
+                        self.tb_writer.add_scalar(
+                            "User-Text Attention", np.mean(self.model.ut_attention), j
+                        )
+                        for name, param in self.model.named_parameters():
+                            self.tb_writer.add_scalar(
+                                name + "/grad_norm",
+                                np.mean(self.model.grad_dict[name]),
+                                j,
+                            )
+                            self.tb_writer.add_histogram(
+                                name + "/grad", param.grad, global_step=epoch
+                            )
+                        self.tb_writer.add_scalar(
+                            "Learning rate", optimizer.param_groups[0]["lr"], j
+                        )
+                        self.model.reset_grad_metrics()
+                    running_loss = 0
+                    running_loss_val = 0
+
+
+                j += 1
+
+            print(f"Epoch: {epoch} is done")
+            # scheduler.step()
+        print("Finished training!")
+        # self.eval_train_set_performance() # evaluate the model on the training set after training if necessary
+
+    def eval_train_set_performance(self) -> Tuple[float, float]:
+        """
+        Evaluate the model's training set performance using Recall@300 and Precision@300.
+        """
+        from cornac.eval_methods.base_method import ranking_eval
+
+        print("Evaluating training set performance at k=300")
+        avg_results, _ = ranking_eval(
+            self,
+            [Recall(k=300), Precision(k=300)],
+            self.train_set,
+            self.train_set,
+            verbose=True,
+            rating_threshold=4,
+        )
+        print(f"Mean train set recall and precision: {avg_results}")
+        return avg_results
+
+    def score(self, user_index: int, item_indices=None):
+        """
+        Scores a user-item pair. If item_indices is None, scores for all known
+        items are returned.
+
+        Parameters
+        ----------
+        user_index: int
+            The index of the user for whom to perform score prediction.
+
+        item_indices: torch.Tensor, optional, default: None
+            The indices of the items for which to perform score prediction.
+            If None, scores for all known items will be returned.
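+
+        Example (illustrative)::
+
+            scores = dmrl.score(user_index=0)  # scores for all known items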
+        """
+        import torch
+
+        self.model.num_neg = 0
+        self.model.eval()
+
+        encoded_image = None
+        encoded_text = None
+
+        if item_indices is None:
+            item_indices = torch.tensor(list(self.iid_map.values()), dtype=torch.long)
+
+        user_index = user_index * torch.ones(len(item_indices), dtype=torch.long)
+
+        # since the model expects shape (batch_size, 1 + num_neg, encoding_dim)
+        # we add a singleton item dimension to the encoded features
+        if hasattr(self, "item_text"):
+            if self.item_text.features is None:
+                self.item_text.preencode_entire_corpus()
+            encoded_text: torch.Tensor = self.item_text.features[item_indices, :]
+            encoded_text = encoded_text[:, None, :]
+            encoded_text = encoded_text.to(self.device)
+
+        if hasattr(self, "item_image"):
+            encoded_image = torch.tensor(
+                self.item_image.features[item_indices, :], dtype=torch.float32
+            )
+            encoded_image = encoded_image[:, None, :]
+            encoded_image = encoded_image.to(self.device)
+
+        input_tensor = torch.stack((user_index, item_indices), axis=1)
+        input_tensor = input_tensor.to(self.device)
+
+        with torch.no_grad():
+            _, ratings_sum_over_mods = self.model(
+                input_tensor, encoded_text, encoded_image
+            )
+
+        return np.array(ratings_sum_over_mods[:, 0].detach().cpu())
+
+    def initialize_and_build_modalities(self, trainset: Dataset):
+        """
+        Initializes the text and image modalities for the model. Either takes
+        in raw text or images and pre-encodes them with the transformer models
+        in TransformersTextModality and TransformersVisionModality, or, if
+        pre-encoded features are given, uses those directly via a general
+        FeatureModality instance, as no further encoding model is required.
+        """
+        from cornac.models.dmrl.transformer_text import TransformersTextModality
+        from cornac.models.dmrl.transformer_vision import TransformersVisionModality
+
+        if trainset.item_text is not None:
+            if (
+                isinstance(trainset.item_text, TextModality)
+                and trainset.item_text.corpus is not None
+            ):
+                self.item_text = TransformersTextModality(
+                    corpus=trainset.item_text.corpus,
+                    ids=trainset.item_text.ids,
+                    preencode=True,
+                )
+            elif isinstance(
+                trainset.item_text, FeatureModality
+            ):  # already have preencoded text features from outside
+                self.item_text = trainset.item_text
+                assert trainset.item_text.features is not None, "No pre-encoded features found, please use TextModality"
+            else:
+                raise ValueError("Not supported type of modality for item text")
+
+        if trainset.item_image is not None:
+            if (
+                isinstance(trainset.item_image, ImageModality)
+                and trainset.item_image.images is not None
+            ):
+                self.item_image = TransformersVisionModality(
+                    images=trainset.item_image.images,
+                    ids=trainset.item_image.ids,
+                    preencode=True,
+                )
+            elif isinstance(
+                trainset.item_image, FeatureModality
+            ):  # already have preencoded image features from outside
+                self.item_image = trainset.item_image
+                assert trainset.item_image.features is not None, "No pre-encoded features found, please use ImageModality"
+            else:
+                raise ValueError("Not supported type of modality for item image")
diff --git a/cornac/models/dmrl/requirements.txt b/cornac/models/dmrl/requirements.txt
new file mode 100644
index 00000000..764643e8
--- /dev/null
+++ b/cornac/models/dmrl/requirements.txt
@@ -0,0 +1,7 @@
+pandas
+torch
+sentence_transformers
+pytest
+dcor
+torchvision
+requests
\ No newline at end of file
diff --git a/cornac/models/dmrl/transformer_text.py b/cornac/models/dmrl/transformer_text.py
new file mode 100644
index 00000000..8d07ee80
--- /dev/null
+++ b/cornac/models/dmrl/transformer_text.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+from typing import List
+from sentence_transformers import SentenceTransformer
+from operator import itemgetter
+
+import torch
+
+from cornac.data.modality import FeatureModality
+
+
+class TransformersTextModality(FeatureModality):
+    """
+    Transformer text modality wrapped around SentenceTransformer library.
+    https://huggingface.co/sentence-transformers.
+
+    Parameters
+    ----------
+    corpus: List[str], default = None
+        List of user/item texts whose indices are aligned with `ids`.
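+
+    Example (illustrative; downloads the sentence-transformers model on first use)::
+
+        modality = TransformersTextModality(
+            corpus=["a shirt", "blue jeans"], ids=[0, 1], preencode=True
+        )
+        features = modality.features  # tensor of shape (2, output_dim)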
+    """
+
+    def __init__(
+            self,
+            corpus: List[str] = None,
+            ids: List = None,
+            preencode: bool = False,
+            model_name_or_path: str = 'paraphrase-MiniLM-L6-v2',
+            **kwargs):
+
+        super().__init__(ids=ids, **kwargs)
+        self.corpus = corpus
+        self.model = SentenceTransformer(model_name_or_path)
+        self.output_dim = self.model[-1].pooling_output_dimension
+        self.preencode = preencode
+        self.preencoded = False
+        
+        if self.preencode:
+            self.preencode_entire_corpus()
+
+    def preencode_entire_corpus(self):
+        """
+        Pre-encode the entire corpus. This is useful so that we don't have to do
+        it on the fly in training. Might take significant time to pre-encode
+        larger datasets.
+        """
+        
+        path = "temp/encoded_corpus.pt"
+        id_path = "temp/encoded_corpus_ids.pt"
+
+        if os.path.exists(path) and os.path.exists(id_path):
+            saved_ids = torch.load(id_path)
+            try:
+                if saved_ids == self.ids:
+                    self.features = torch.load(path)
+                    self.preencoded = True
+                else:
+                    assert self.preencoded is False
+            except:  # noqa: E722
+                print("The ids of the saved encoded corpus do not match the current ids. Re-encoding the corpus.")
+        
+        if not self.preencoded:
+            print("Pre-encoding the entire corpus. This might take a while.")
+            self.features = self.model.encode(self.corpus, convert_to_tensor=True)
+            self.preencoded = True
+            os.makedirs("temp", exist_ok = True)
+            torch.save(self.features, path)
+            torch.save(self.ids, id_path)
+
+    def batch_encode(self, ids: List[int]):
+        """
+        Batch encode on the fly the list of item ids
+
+        Parameters
+        ----------
+        ids: List[int]
+            List of item ids to encode.
+        """
+
+        text_batch = list(itemgetter(*ids)(self.corpus))
+        encoded = self.model.encode(text_batch, convert_to_tensor=True)
+
+        return encoded
diff --git a/cornac/models/dmrl/transformer_vision.py b/cornac/models/dmrl/transformer_vision.py
new file mode 100644
index 00000000..967b3b5f
--- /dev/null
+++ b/cornac/models/dmrl/transformer_vision.py
@@ -0,0 +1,149 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import List
+from cornac.data.modality import FeatureModality
+import os
+from PIL.JpegImagePlugin import JpegImageFile
+import torch
+from torchvision import transforms, models
+from torchvision.models._api import WeightsEnum
+
+
+class TransformersVisionModality(FeatureModality):
+    """
+    Transformer vision modality wrapped around the torchvision ViT Transformer.
+
+    Parameters
+    ----------
+    images: List[JpegImageFile], default = None
+        List of item images whose indices are aligned with `ids`.
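+
+    Example (illustrative; downloads the large ViT-H/14 weights on first use,
+    and `pil_images` is a list of PIL JPEG images)::
+
+        modality = TransformersVisionModality(
+            images=pil_images, ids=list(range(len(pil_images))), preencode=True
+        )
+        features = modality.features  # one encoded feature row per image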
+    """
+
+    def __init__(
+        self,
+        images: List[JpegImageFile] = None,
+        ids: List = None,
+        preencode: bool = False,
+        model_weights: WeightsEnum = models.ViT_H_14_Weights.DEFAULT,
+        **kwargs
+    ):
+
+        super().__init__(ids=ids, **kwargs)
+        self.images = images
+        self.model = models.vit_h_14(weights=model_weights)
+        # suppress the classification piece
+        self.model.heads = torch.nn.Identity()
+        self.model.eval()
+
+        self.image_size = (self.model.image_size, self.model.image_size)
+        self.image_to_tensor_transformer = transforms.Compose(
+            [
+                transforms.ToTensor()
+                # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize pixel values
+            ]
+        )
+
+        self.preencode = preencode
+        self.preencoded = False
+        self.batch_size = 50
+        
+        if self.preencode:
+            self.preencode_images()
+
+    def preencode_images(self):
+        """
+        Pre-encode the entire image library. This is useful so that we don't
+        have to do it on the fly in training. Might take significant time to
+        pre-encode.
+        """
+
+        path = "temp/encoded_images.pt"
+        id_path = "temp/encoded_images_ids.pt"
+
+        if os.path.exists(path) and os.path.exists(id_path):
+            saved_ids = torch.load(id_path)
+            if saved_ids == self.ids:
+                self.features = torch.load(path)
+                self.preencoded = True
+            else:
+                assert self.preencoded is False
+                print(
+                    "The ids of the saved encoded images do not match the current ids. Re-encoding the images."
+                )
+
+        if not self.preencoded:
+            print("Pre-encoding the entire image library. This might take a while.")
+            self._encode_images()
+            self.preencoded = True
+            os.makedirs("temp", exist_ok=True)
+            torch.save(self.features, path)
+            torch.save(self.ids, id_path)
+
+    def _encode_images(self):
+        """
+        Encode all images in the library.
+        """
+        # ceil-divide so the final partial batch is included without producing
+        # an empty batch when the image count is a multiple of batch_size
+        num_batches = (len(self.images) + self.batch_size - 1) // self.batch_size
+        for i in range(num_batches):
+            tensor_batch = self.transform_images_to_torch_tensor(
+                self.images[i * self.batch_size : (i + 1) * self.batch_size]
+            )
+            with torch.no_grad():
+                encoded_batch = self.model(tensor_batch)
+
+            if i == 0:
+                self.features = encoded_batch
+            else:
+                self.features = torch.cat((self.features, encoded_batch), 0)
+
+    def transform_images_to_torch_tensor(
+        self, images: List[JpegImageFile]
+    ) -> torch.Tensor:
+        """
+        Transforms a list of PIL images into a batched torch tensor.
+
+        Parameters
+        ----------
+        images: List[PIL.Image]
+            List of PIL images to be transformed to torch tensor.
+        """
+        tensors = []
+        for img in images:
+            if img.size != self.image_size:
+                img = img.resize(self.image_size)
+            tensors.append(self.image_to_tensor_transformer(img))
+
+        return torch.stack(tensors)
+
+    def batch_encode(self, ids: List[int]):
+        """
+        Encode the images for the given list of item ids on the fly.
+
+        Parameters
+        ----------
+        ids: List[int]
+            List of item ids to encode.
+        """
+        tensor_batch = self.transform_images_to_torch_tensor(
+            [self.images[i] for i in ids]
+        )
+        with torch.no_grad():
+            encoded_batch = self.model(tensor_batch)
+
+        return encoded_batch
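+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (not part of the library API): encode two local
+    # JPEGs and inspect the resulting embeddings. The file names below are
+    # hypothetical; any RGB JPEG images will work. Note that instantiating
+    # the modality downloads the large ViT-H/14 weights on first use.
+    from PIL import Image
+
+    images = [Image.open("item_0.jpg"), Image.open("item_1.jpg")]
+    modality = TransformersVisionModality(images=images, ids=[0, 1])
+    features = modality.batch_encode([0, 1])
+    print(features.shape)  # one embedding row per image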
diff --git a/docs/source/api_ref/models.rst b/docs/source/api_ref/models.rst
index 2b2c5992..4fc2f26e 100644
--- a/docs/source/api_ref/models.rst
+++ b/docs/source/api_ref/models.rst
@@ -11,6 +11,11 @@ Recommender (Generic Class)
 .. automodule:: cornac.models.recommender
    :members:
 
+Disentangled Multimodal Representation Learning for Recommendation (DMRL)
+-------------------------------------------------------------------------
+.. automodule:: cornac.models.dmrl.recom_dmrl
+   :members:
+
 Bilateral VAE for Collaborative Filtering (BiVAECF)
 ---------------------------------------------------
 .. automodule:: cornac.models.bivaecf.recom_bivaecf
diff --git a/examples/README.md b/examples/README.md
index ab5be47b..794218d6 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -54,6 +54,8 @@
 
 [cvae_example.py](cvae_example.py) - Collaborative Variational Autoencoder (CVAE) with CiteULike dataset.
 
+[dmrl_example.py](dmrl_example.py) - Disentangled Multimodal Representation Learning (DMRL) with CiteULike dataset.
+
 [trirank_example.py](trirank_example.py) - TriRank with Amazon Toy and Games dataset.
 
 [efm_example.py](efm_example.py) - Explicit Factor Model (EFM) with Amazon Toy and Games dataset.
@@ -68,6 +70,8 @@
 
 [causalrec_clothing.py](causalrec_clothing.py) - CausalRec with Clothing dataset.
 
+[dmrl_clothes_example.py](dmrl_clothes_example.py) - Disentangled Multimodal Representation Learning (DMRL) with Amazon Clothing dataset.
+
 [vbpr_tradesy.py](vbpr_tradesy.py) - Visual Bayesian Personalized Ranking (VBPR) with Tradesy dataset.
 
 [vmf_clothing.py](vmf_clothing.py) - Visual Matrix Factorization (VMF) with Amazon Clothing dataset.
diff --git a/examples/dmrl_clothes_example.py b/examples/dmrl_clothes_example.py
new file mode 100644
index 00000000..9099c641
--- /dev/null
+++ b/examples/dmrl_clothes_example.py
@@ -0,0 +1,59 @@
+"""
+Example for Disentangled Multimodal Recommendation (DMRL), with feedback, textual, and visual modalities.
+This example uses pre-encoded visual features from the Cornac dataset instead of the TransformersVisionModality.
+"""
+
+import cornac
+from cornac.data import TextModality, ImageModality
+from cornac.datasets import amazon_clothing
+from cornac.eval_methods import RatioSplit
+
+
+feedback = amazon_clothing.load_feedback()
+image_features, image_item_ids = amazon_clothing.load_visual_feature()  # BIG file
+docs, text_item_ids = amazon_clothing.load_text()
+
+
+# only treat ratings of 4 or higher as positive user-item pairs
+new_feedback = [f for f in feedback if f[2] >= 4]
+
+text_modality = TextModality(corpus=docs, ids=text_item_ids)
+image_modality = ImageModality(features=image_features, ids=image_item_ids)
+
+ratio_split = RatioSplit(
+    data=new_feedback,
+    test_size=0.25,
+    exclude_unknowns=True,
+    verbose=True,
+    seed=123,
+    rating_threshold=4,
+    item_text=text_modality,
+    item_image=image_modality,
+)
+
+dmrl_recommender = cornac.models.dmrl.DMRL(
+    batch_size=1024,
+    epochs=60,
+    log_metrics=False,
+    learning_rate=0.001,
+    num_factors=2,
+    decay_r=2,
+    decay_c=0.1,
+    num_neg=5,
+    embedding_dim=100,
+    image_dim=4096,
+    dropout=0,
+)
+
+
+# Use Recall@300, Recall@900, and Precision@30 for evaluation
+rec_300 = cornac.metrics.Recall(k=300)
+rec_900 = cornac.metrics.Recall(k=900)
+prec_30 = cornac.metrics.Precision(k=30)
+
+# Put everything together into an experiment and run it
+cornac.Experiment(
+    eval_method=ratio_split,
+    models=[dmrl_recommender],
+    metrics=[prec_30, rec_300, rec_900],
+).run()
diff --git a/examples/dmrl_example.py b/examples/dmrl_example.py
new file mode 100644
index 00000000..3b9fc067
--- /dev/null
+++ b/examples/dmrl_example.py
@@ -0,0 +1,50 @@
+"""Example for Disentangled Multimodal Recommendation, with only feedback and textual modality.
+For an example including image modality please see dmrl_clothes_example.py"""
+
+import cornac
+from cornac.data import Reader
+from cornac.datasets import citeulike
+from cornac.eval_methods import RatioSplit
+from cornac.data import TextModality
+
+# The necessary data can be loaded as follows
+docs, item_ids = citeulike.load_text()
+feedback = citeulike.load_feedback(reader=Reader(item_set=item_ids))
+
+item_text_modality = TextModality(
+    corpus=docs,
+    ids=item_ids,
+)
+
+# Define an evaluation method to split feedback into train and test sets
+ratio_split = RatioSplit(
+    data=feedback,
+    test_size=0.2,
+    exclude_unknowns=True,
+    verbose=True,
+    seed=123,
+    rating_threshold=0.5,
+    item_text=item_text_modality,
+)
+
+# Instantiate DMRL recommender
+dmrl_recommender = cornac.models.dmrl.DMRL(
+    batch_size=4096,
+    epochs=20,
+    log_metrics=False,
+    learning_rate=0.01,
+    num_factors=2,
+    decay_r=0.5,
+    decay_c=0.01,
+    num_neg=3,
+    embedding_dim=100,
+)
+
+# Use Recall@300 and Precision@30 for evaluation
+rec_300 = cornac.metrics.Recall(k=300)
+prec_30 = cornac.metrics.Precision(k=30)
+
+# Put everything together into an experiment and run it
+cornac.Experiment(
+    eval_method=ratio_split, models=[dmrl_recommender], metrics=[prec_30, rec_300]
+).run()
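+
+# After the experiment, the fitted model can be queried directly. A sketch,
+# assuming the generic cornac Recommender API (user index 0 is arbitrary):
+# item_scores = dmrl_recommender.score(user_idx=0)
+# print(item_scores.argsort()[-5:][::-1])  # indices of the top-5 items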
diff --git a/pytest.ini b/pytest.ini
index 73ca0beb..4a940238 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,7 +1,8 @@
 # Configuration of py.test
 [pytest]
 norecursedirs = tests/cornac/datasets
-    
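+# make the repository root importable so tests can import cornac from source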
+pythonpath = .
+
 addopts=-v
         --durations=20
         --ignore=tests/cornac/utils/test_download.py
diff --git a/tests/cornac/models/__init__.py b/tests/cornac/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/cornac/models/dmrl/__init__.py b/tests/cornac/models/dmrl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/cornac/models/dmrl/test_distance_calc.py b/tests/cornac/models/dmrl/test_distance_calc.py
new file mode 100644
index 00000000..b0795bf6
--- /dev/null
+++ b/tests/cornac/models/dmrl/test_distance_calc.py
@@ -0,0 +1,47 @@
+"""
+Pytest tests for cornac.models.dmrl.d_cor_calc.
+"""
+
+try:
+    import torch
+    import dcor
+    from cornac.models.dmrl.d_cor_calc import DistanceCorrelationCalculator
+
+    run_dmrl_test_funcs = True
+
+except ImportError:
+    run_dmrl_test_funcs = False
+
+
+def skip_test_in_case_of_missing_reqs(test_func):
+    test_func.__test__ = (
+        run_dmrl_test_funcs  # Mark the test function as (non-)discoverable by unittest
+    )
+    return test_func
+
+
+# first test the distance correlation calculator
+@skip_test_in_case_of_missing_reqs
+def test_distance_correlation_calculator():
+    """
+    Test the distance correlation calculator by comparing against the dcor library.
+    """
+    num_neg = 4 + 1
+    distance_cor_calc = DistanceCorrelationCalculator(n_factors=2, num_neg=num_neg)
+    # create some fake data
+    tensor_x = torch.randn(3, num_neg, 10)
+    tensor_y = torch.randn(3, num_neg, 10)
+    assert tensor_x.shape[1] == num_neg
+    assert tensor_y.shape[1] == num_neg
+
+    cor_per_sample = distance_cor_calc.calculate_cor(tensor_x, tensor_y)
+    assert cor_per_sample.shape[0] == tensor_x.shape[1]
+
+    for sample in range(num_neg - 1):
+        # compare values rounded to 2 decimal places
+        assert round(cor_per_sample[sample].item(), 2) == round(
+            dcor.distance_correlation(
+                tensor_x[:, sample, :], tensor_y[:, sample, :]
+            ).item(),
+            2,
+        )
diff --git a/tests/cornac/models/dmrl/test_pwlearning_sampler.py b/tests/cornac/models/dmrl/test_pwlearning_sampler.py
new file mode 100644
index 00000000..cd58d311
--- /dev/null
+++ b/tests/cornac/models/dmrl/test_pwlearning_sampler.py
@@ -0,0 +1,100 @@
+# Check that all requirements needed for the imports below are actually present;
+# if any are missing, the respective tests are skipped.
+# To run these tests, install the requirements: pip install -r cornac/models/dmrl/requirements.txt
+
+
+import unittest
+
+from cornac.data.dataset import Dataset
+from cornac.data.reader import Reader
+from cornac.datasets import citeulike
+
+try:
+    from torch.utils.data import DataLoader
+    from cornac.models.dmrl.pwlearning_sampler import PWLearningSampler
+
+    run_dmrl_test_funcs = True
+
+except ImportError:
+    run_dmrl_test_funcs = False
+
+
+def skip_test_in_case_of_missing_reqs(test_func):
+    test_func.__test__ = (
+        run_dmrl_test_funcs  # Mark the test function as (non-)discoverable by unittest
+    )
+    return test_func
+
+
+class TestPWLearningSampler(unittest.TestCase):
+    """
+    Tests that the PW Sampler returns the desired batch of data.
+    """
+
+    @skip_test_in_case_of_missing_reqs
+    def setUp(self):
+        self.num_neg = 4
+        _, item_ids = citeulike.load_text()
+        feedback = citeulike.load_feedback(reader=Reader(item_set=item_ids))
+        cornac_dataset = Dataset.build(data=feedback)
+        self.sampler = PWLearningSampler(cornac_dataset, num_neg=self.num_neg)
+
+    @skip_test_in_case_of_missing_reqs
+    def test_get_batch_multiprocessed(self):
+        """
+        Tests multi-process loading via the torch DataLoader.
+        """
+        batch_size = 32
+        dataloader = DataLoader(
+            self.sampler,
+            batch_size=batch_size,
+            num_workers=3,
+            shuffle=True,
+            prefetch_factor=3,
+        )
+        generator_data_loader = iter(dataloader)
+        batch = next(generator_data_loader)
+        assert batch.shape == (batch_size, 2 + self.num_neg)
+
+    @skip_test_in_case_of_missing_reqs
+    def test_correctness(self):
+        """
+        Tests the correctness of the PWLearningSampler by asserting that the
+        correct positive and negative user-item pairs are returned.
+        """
+        batch_size = 32
+        dataloader = DataLoader(
+            self.sampler,
+            batch_size=batch_size,
+            num_workers=0,
+            shuffle=True,
+            prefetch_factor=None,
+        )
+        generator_data_loader = iter(dataloader)
+        batch = next(generator_data_loader).numpy()
+        assert batch.shape == (batch_size, 2 + self.num_neg)
+        for i in range(batch_size):
+            user = batch[i, 0]
+            pos_item = batch[i, 1]
+            neg_items = batch[i, 2:]
+            assert pos_item in self.sampler.data.csr_matrix[user].nonzero()[1]
+            for neg_item in neg_items:
+                assert neg_item not in self.sampler.data.csr_matrix[user].nonzero()[1]
+
+    @skip_test_in_case_of_missing_reqs
+    def test_full_epoch_sampler(self):
+        """
+        Tests that the loader yields the expected number of batches over a full epoch.
+        """
+        batch_size = 32
+        dataloader = DataLoader(
+            self.sampler,
+            batch_size=batch_size,
+            num_workers=0,
+            shuffle=True,
+            prefetch_factor=None,
+        )
+        i = 0
+        for _ in dataloader:
+            i += 1
+        assert i == self.sampler.user_array.shape[0] // batch_size + 1
diff --git a/tests/cornac/models/dmrl/test_transformertext.py b/tests/cornac/models/dmrl/test_transformertext.py
new file mode 100644
index 00000000..7637ae4c
--- /dev/null
+++ b/tests/cornac/models/dmrl/test_transformertext.py
@@ -0,0 +1,53 @@
+# Check that all requirements needed for the imports below are actually present;
+# if any are missing, the respective tests are skipped.
+# To run these tests, install the requirements: pip install -r cornac/models/dmrl/requirements.txt
+
+import unittest
+
+try:
+    import torch
+    from sentence_transformers import util
+    from cornac.models.dmrl.transformer_text import TransformersTextModality
+
+    run_dmrl_test_funcs = True
+
+except ImportError:
+    run_dmrl_test_funcs = False
+
+
+def skip_test_in_case_of_missing_reqs(test_func):
+    test_func.__test__ = (
+        run_dmrl_test_funcs  # Mark the test function as (non-)discoverable by unittest
+    )
+    return test_func
+
+
+class TestTransformersTextModality(unittest.TestCase):
+    @skip_test_in_case_of_missing_reqs
+    def setUp(self):
+        self.corpus = ["I like you very much.", "I like you so much"]
+        self.ids = [0, 1]
+        self.modality = TransformersTextModality(
+            corpus=self.corpus, ids=self.ids, preencode=True
+        )
+
+    @skip_test_in_case_of_missing_reqs
+    def test_batch_encode(self):
+        encoded_batch = self.modality.batch_encode(self.ids)
+
+        assert encoded_batch.shape[0] == 2
+        assert isinstance(encoded_batch, torch.Tensor)
+
+    @skip_test_in_case_of_missing_reqs
+    def test_preencode_entire_corpus(self):
+        self.modality.preencode_entire_corpus()
+        assert self.modality.preencoded
+
+        assert torch.load("temp/encoded_corpus_ids.pt") == self.ids
+        assert torch.load("temp/encoded_corpus.pt").shape[0] == len(self.corpus)
+
+    @skip_test_in_case_of_missing_reqs
+    def test_batch_encode_similarity(self):
+        encoded_batch = self.modality.batch_encode(self.ids)
+        similarity = util.cos_sim(encoded_batch[0], encoded_batch[1])
+        assert similarity > 0.9
diff --git a/tests/cornac/models/dmrl/test_transformervision.py b/tests/cornac/models/dmrl/test_transformervision.py
new file mode 100644
index 00000000..9452fc46
--- /dev/null
+++ b/tests/cornac/models/dmrl/test_transformervision.py
@@ -0,0 +1,112 @@
+"""
+Tests for the TransformersVisionModality class. To run these tests, replace the
+placeholder URLs in `beach_urls` and `cat_url` below with links to your favorite
+beach and cat photos; the tests then check the similarity scores between them.
+"""
+
+import unittest
+
+# Check that all requirements needed for the imports below are actually present;
+# if any are missing, the respective tests are skipped.
+# To run these tests, install the requirements: pip install -r cornac/models/dmrl/requirements.txt
+try:
+    import torch
+    import requests
+    from PIL import Image
+    from sentence_transformers import util
+
+    from cornac.models.dmrl.transformer_vision import TransformersVisionModality
+
+    run_dmrl_test_funcs = True
+
+except ImportError:
+    run_dmrl_test_funcs = False
+
+
+def skip_test_in_case_of_missing_reqs(test_func):
+    test_func.__test__ = (
+        run_dmrl_test_funcs  # Mark the test function as (non-)discoverable by unittest
+    )
+    return test_func
+
+
+# Please insert valid URLs here for two beach photos and one cat photo
+beach_urls = ["url_to_beach1", "url_to_beach2"]
+cat_url = "url_to_cat"
+
+
+class TestTransformersVisionModality(unittest.TestCase):
+
+    def get_photos(self):
+        for i, url in enumerate(beach_urls):
+            r = requests.get(url)
+            with open(f"beach{i}.jpg", "wb") as f:
+                f.write(r.content)
+
+        r = requests.get(cat_url)
+        with open("cat.jpg", "wb") as f:
+            f.write(r.content)
+
+    @skip_test_in_case_of_missing_reqs
+    def setUp(self):
+        self.get_photos()
+        beach1 = Image.open("beach0.jpg")
+        beach2 = Image.open("beach1.jpg")
+        cat = Image.open("cat.jpg")
+        self.images = [beach1, beach2, cat]
+        self.ids = [0, 1, 2]  # one id per image, aligned with self.images
+        self.modality = TransformersVisionModality(
+            images=self.images, ids=self.ids, preencode=True
+        )
+
+    @skip_test_in_case_of_missing_reqs
+    @unittest.skipIf(
+        "url_to_beach1" in beach_urls,
+        "Please insert a valid url to download 2 beach and one cat photo",
+    )
+    def test_transform_image_to_tensor(self):
+        """
+        Tests that a list of images is transformed correctly into a tensor batch
+        """
+        image_tensor_batch = self.modality.transform_images_to_torch_tensor(self.images)
+        assert isinstance(image_tensor_batch, torch.Tensor)
+        assert image_tensor_batch.shape[0:2] == torch.Size(
+            (3, 3)
+        )  # 3 images with 3 channels each
+        assert image_tensor_batch.shape[2:] == torch.Size(self.modality.image_size)
+
+    @skip_test_in_case_of_missing_reqs
+    @unittest.skipIf(
+        "url_to_beach1" in beach_urls,
+        "Please insert a valid url to download 2 beach and one cat photo",
+    )
+    def test_encode_all_images(self):
+        """
+        Tests that all images are encoded
+        """
+        self.modality._encode_images()
+        assert isinstance(self.modality.features, torch.Tensor)
+        assert self.modality.features.shape[0] == len(self.images)
+        assert self.modality.features.shape[1] == 1280  # ViT-H/14 hidden dim (classification head removed)
+
+    @skip_test_in_case_of_missing_reqs
+    @unittest.skipIf(
+        "url_to_beach1" in beach_urls,
+        "Please insert a valid url to download 2 beach and one cat photo",
+    )
+    def test_encoding_quality(self):
+        """
+        Tests similarity in latent space between the images
+        """
+        self.modality._encode_images()
+        beach1_beach2_similarity = util.cos_sim(
+            self.modality.features[0], self.modality.features[1]
+        )
+        assert beach1_beach2_similarity > 0.7
+
+        beach_cat_similarity = util.cos_sim(
+            self.modality.features[0], self.modality.features[2]
+        )
+        assert beach_cat_similarity < 0.1
+
+        assert beach1_beach2_similarity > beach_cat_similarity
-- 
GitLab