Unverified commit 296d2d90 authored by Max Beckers, committed by GitHub

Add DMRL model (#597)

parent 44a8fc9c
Showing 1680 additions and 1 deletion
...@@ -148,6 +148,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
| Year | Model and paper | Model type | Require-ments | Examples |
| :---: | --- | :---: | :---: | :---: |
| 2024 | [Hypergraphs with Attention on Reviews (HypAR)](cornac/models/hypar), [paper](https://doi.org/10.1007/978-3-031-56027-9_14) | Hybrid / Sentiment / Explainable | [reqs](cornac/models/hypar/requirements_cu116.txt) | [exp](https://github.com/PreferredAI/HypAR) |
| 2022 | [Disentangled Multimodal Representation Learning for Recommendation (DMRL)](cornac/models/dmrl), [paper](https://arxiv.org/pdf/2203.05406.pdf) | Content-Based / Text & Image | [reqs](cornac/models/dmrl/requirements.txt) | [exp](examples/dmrl_example.py) |
| 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | Collaborative Filtering / Content-Based | [reqs](cornac/models/bivaecf/requirements.txt) | [exp](https://github.com/PreferredAI/bi-vae) |
| | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | Content-Based / Image | [reqs](cornac/models/causalrec/requirements.txt) | [exp](examples/causalrec_clothing.py) |
| | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | Explainable | N/A | [exp](https://github.com/PreferredAI/ComparER) |
...
...@@ -38,6 +38,7 @@ from .conv_mf import ConvMF
from .ctr import CTR
from .cvae import CVAE
from .cvaecf import CVAECF
from .dmrl import DMRL
from .dnntsp import DNNTSP
from .ease import EASE
from .efm import EFM
...
from .recom_dmrl import DMRL
\ No newline at end of file
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
class DistanceCorrelationCalculator:
"""
    Calculates the disentangled loss for the DMRL model.
Please see https://arxiv.org/pdf/2203.05406.pdf for more details.
"""
def __init__(self, n_factors, num_neg) -> None:
self.n_factors = n_factors
self.num_neg = num_neg
def calculate_cov(self, X, Y):
"""
Computes the distance covariance between X and Y.
:param X: A 3D torch tensor.
:param Y: A 3D torch tensor.
:return: A 1D torch tensor of len 1+num_neg.
"""
# first create centered distance matrices
X = self.cent_dist(X)
Y = self.cent_dist(Y)
# batch_size is dim 1, as dim 0 is one positive and num_neg negative samples
n_samples = X.shape[1]
# then calculate the covariance as a 1D array of length 1+num_neg
cov = torch.sqrt(torch.max(torch.sum(X * Y, dim=(1, 2)) / (n_samples * n_samples), torch.tensor(1e-5)))
return cov
def calculate_var(self, X):
"""
Computes the distance variance of X.
:param X: A 3D torch tensor.
        :return: A 1D torch tensor of len 1+num_neg.
"""
return self.calculate_cov(X, X)
def calculate_cor(self, X, Y):
"""
Computes the distance correlation between X and Y.
:param X: A 3D torch tensor.
:param Y: A 3D torch tensor.
        :return: A 1D torch tensor of len 1+num_neg.
"""
return self.calculate_cov(X, Y) / torch.sqrt(torch.max(self.calculate_var(X) * self.calculate_var(Y), torch.tensor(0.0)))
def cent_dist(self, X):
"""
Computes the pairwise euclidean distance between rows of X and centers
each cell of the distance matrix with row mean, column mean, and grand mean.
"""
# put the samples from dim 1 into dim 0
X = torch.transpose(X, dim0=0, dim1=1)
# Now use pythagoras to calculate the distance matrix
first_part = torch.sum(torch.square(X), dim=-1, keepdims=True)
middle_part = torch.matmul(X, torch.transpose(X, dim0=1, dim1=2))
last_part = torch.transpose(first_part, dim0=1, dim1=2)
D = torch.sqrt(torch.max(first_part - 2 * middle_part + last_part, torch.tensor(1e-5)))
# dim0 is the negative samples, dim1 is batch_size, dim2 is the kth factor of the embedding_dim
row_mean = torch.mean(D, dim=2, keepdim=True)
column_mean = torch.mean(D, dim=1, keepdim=True)
global_mean = torch.mean(D, dim=(1, 2), keepdim=True)
D = D - row_mean - column_mean + global_mean
return D
def calculate_disentangled_loss(
self,
item_embedding_factors: torch.Tensor,
user_embedding_factors: torch.Tensor,
text_embedding_factors: torch.Tensor,
image_embedding_factors: torch.Tensor):
"""
Calculates the disentangled loss for the given factors.
:param item_embedding_factors: A list of 3D torch tensors.
:param user_embedding_factors: A list of 3D torch tensors.
        :param text_embedding_factors: A list of 3D torch tensors.
        :param image_embedding_factors: A list of 3D torch tensors.
        :return: A 1D torch tensor of len 1+num_neg.
"""
cor_loss = torch.tensor([0.0] * (1 + self.num_neg))
        # iterate over all unordered factor pairs (i, j) with i < j
        for i in range(self.n_factors - 1):
            for j in range(i + 1, self.n_factors):
                cor_loss += self.calculate_cor(item_embedding_factors[i], item_embedding_factors[j])
                cor_loss += self.calculate_cor(user_embedding_factors[i], user_embedding_factors[j])
                if text_embedding_factors[i].numel() > 0:
                    cor_loss += self.calculate_cor(text_embedding_factors[i], text_embedding_factors[j])
                if image_embedding_factors[i].numel() > 0:
                    cor_loss += self.calculate_cor(image_embedding_factors[i], image_embedding_factors[j])
        # average over the number of distinct factor pairs
        num_pairs = max(self.n_factors * (self.n_factors - 1) // 2, 1)
        cor_loss = cor_loss / num_pairs
        # Two options: either return the sum over the one positive and num_neg
        # negative samples, or return only the loss of the positive sample, as
        # done in the paper.
        # return torch.sum(cor_loss)
        return cor_loss[0]
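# A minimal usage sketch (illustrative, not part of the committed module): the
# factor tensors DMRL feeds in are shaped (batch_size, 1 + num_neg, factor_dim),
# and one distance correlation value comes back per positive/negative slot.
if __name__ == "__main__":
    num_neg, batch_size, factor_dim = 4, 8, 10
    calc = DistanceCorrelationCalculator(n_factors=2, num_neg=num_neg)
    x = torch.randn(batch_size, 1 + num_neg, factor_dim)
    y = torch.randn(batch_size, 1 + num_neg, factor_dim)
    cor = calc.calculate_cor(x, y)
    print(cor.shape)  # torch.Size([5]), one value per 1 + num_neg sample slots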
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List, Tuple
import torch
import torch.nn as nn
from cornac.models.dmrl.d_cor_calc import DistanceCorrelationCalculator
from dataclasses import dataclass
from cornac.utils.common import get_rng
from cornac.utils.init_utils import normal, xavier_normal, xavier_uniform
@dataclass
class EmbeddingFactorLists:
"""
A dataclass for holding the embedding factors for each modality.
"""
user_embedding_factors: List[torch.Tensor]
item_embedding_factors: List[torch.Tensor]
text_embedding_factors: List[torch.Tensor] = None
image_embedding_factors: List[torch.Tensor] = None
class DMRLModel(nn.Module):
"""
The actual Disentangled Multi-Modal Recommendation Model neural network.
"""
def __init__(
self,
num_users: int,
num_items: int,
embedding_dim: int,
text_dim: int,
image_dim: int,
dropout: float,
num_neg: int,
num_factors: int,
seed: int = 123,
):
super(DMRLModel, self).__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.num_factors = num_factors
self.num_neg = num_neg
self.embedding_dim = embedding_dim
self.num_modalities = 1 + bool(text_dim) + bool(image_dim)
self.dropout = dropout
self.grad_norms = []
self.param_norms = []
self.ui_ratings = []
self.ut_ratings = []
self.ui_attention = []
self.ut_attention = []
        rng = get_rng(seed)
if text_dim:
self.text_module = torch.nn.Sequential(
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(text_dim, 150),
torch.nn.LeakyReLU(),
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(150, embedding_dim),
torch.nn.LeakyReLU(),
)
            self.text_module[1].weight.data = torch.from_numpy(
                xavier_normal([150, text_dim], random_state=rng)
            )
            self.text_module[4].weight.data = torch.from_numpy(
                xavier_normal([embedding_dim, 150], random_state=rng)
            )
if image_dim:
self.image_module = torch.nn.Sequential(
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(image_dim, 150),
torch.nn.LeakyReLU(),
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(150, embedding_dim),
torch.nn.LeakyReLU(),
)
self.user_embedding = torch.nn.Embedding(num_users, embedding_dim)
self.item_embedding = torch.nn.Embedding(num_items, embedding_dim)
        self.user_embedding.weight.data = torch.from_numpy(
            xavier_normal([num_users, embedding_dim], random_state=rng)
        )
        self.item_embedding.weight.data = torch.from_numpy(
            xavier_normal([num_items, embedding_dim], random_state=rng)
        )
self.factor_size = self.embedding_dim // self.num_factors
self.attention_layer = torch.nn.Sequential(
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(
(self.num_modalities + 1) * self.factor_size, self.num_modalities
),
torch.nn.Tanh(),
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(self.num_modalities, self.num_modalities, bias=False),
torch.nn.Softmax(dim=-1),
)
        self.attention_layer[1].weight.data = torch.from_numpy(
            xavier_normal(
                [self.num_modalities, (self.num_modalities + 1) * self.factor_size],
                random_state=rng,
            )
        )
        self.attention_layer[4].weight.data = torch.from_numpy(
            xavier_normal([self.num_modalities, self.num_modalities], random_state=rng)
        )
self.grad_dict = {i[0]: [] for i in self.named_parameters()}
def forward(
self, batch: torch.Tensor, text: torch.Tensor, image: torch.Tensor
) -> Tuple[EmbeddingFactorLists, torch.Tensor]:
"""
Forward pass of the model.
        Parameters
        ----------
batch: torch.Tensor
A batch of data. The first column contains the user indices, the
rest of the columns contain the item indices (one pos and num_neg negatives)
text: torch.Tensor
The text data for the items in the batch (encoded)
image: torch.Tensor
The image data for the items in the batch (encoded)
"""
text_embedding_factors = [
torch.tensor([]).to(self.device) for _ in range(self.num_factors)
]
image_embedding_factors = [
torch.tensor([]).to(self.device) for _ in range(self.num_factors)
]
users = batch[:, 0]
items = batch[:, 1:]
# handle text
if text is not None:
text_embedding = self.text_module(
torch.nn.functional.normalize(text, dim=-1)
)
text_embedding_factors = torch.split(
text_embedding, self.embedding_dim // self.num_factors, dim=-1
)
# handle image
if image is not None:
image_embedding = self.image_module(
torch.nn.functional.normalize(image, dim=-1)
)
image_embedding_factors = torch.split(
image_embedding, self.embedding_dim // self.num_factors, dim=-1
)
# handle users
user_embedding = self.user_embedding(users)
# we have to get users into shape batch, 1+num_neg, embedding_dim
# therefore we repeat the users across the 1 pos and num_neg items
user_embedding_inflated = user_embedding.unsqueeze(1).repeat(
1, items.shape[1], 1
)
user_embedding_factors = torch.split(
user_embedding_inflated, self.embedding_dim // self.num_factors, dim=-1
)
# handle items
item_embedding = self.item_embedding(items)
item_embedding_factors = torch.split(
item_embedding, self.embedding_dim // self.num_factors, dim=-1
)
embedding_factor_lists = EmbeddingFactorLists(
user_embedding_factors,
item_embedding_factors,
text_embedding_factors,
image_embedding_factors,
)
        # attention is computed per factor k
batch_size = users.shape[0]
ratings_sum_over_mods = torch.zeros((batch_size, 1 + self.num_neg)).to(
self.device
)
for i in range(self.num_factors):
concatted_features = torch.concatenate(
[
user_embedding_factors[i],
item_embedding_factors[i],
text_embedding_factors[i],
image_embedding_factors[i],
],
axis=2,
)
attention = self.attention_layer(
torch.nn.functional.normalize(concatted_features, dim=-1)
)
r_ui = attention[:, :, 0] * torch.nn.Softplus()(
torch.sum(
user_embedding_factors[i] * item_embedding_factors[i], axis=-1
)
)
# log rating
self.ui_ratings.append(torch.norm(r_ui.detach().flatten()).cpu())
factor_rating = r_ui
if text is not None:
r_ut = attention[:, :, 1] * torch.nn.Softplus()(
torch.sum(
user_embedding_factors[i] * text_embedding_factors[i], axis=-1
)
)
factor_rating = factor_rating + r_ut
# log rating
self.ut_ratings.append(torch.norm(r_ut.detach().flatten()).cpu())
            if image is not None:
                # the image modality occupies the last attention slot (index 1
                # if image is the only extra modality, 2 if text is also present)
                r_uv = attention[:, :, self.num_modalities - 1] * torch.nn.Softplus()(
                    torch.sum(
                        user_embedding_factors[i] * image_embedding_factors[i], axis=-1
                    )
                )
                factor_rating = factor_rating + r_uv
                # log rating
                self.ui_ratings.append(torch.norm(r_uv.detach().flatten()).cpu())
# sum up over modalities and running sum over factors
ratings_sum_over_mods = ratings_sum_over_mods + factor_rating
return embedding_factor_lists, ratings_sum_over_mods
def log_gradients_and_weights(self):
"""
Stores most recent gradient norms in a list.
"""
for i in self.named_parameters():
self.grad_dict[i[0]].append(torch.norm(i[1].grad.detach().flatten()).item())
total_norm_grad = torch.norm(
torch.cat([p.grad.detach().flatten() for p in self.parameters()])
)
self.grad_norms.append(total_norm_grad.item())
total_norm_param = torch.norm(
torch.cat([p.detach().flatten() for p in self.parameters()])
)
self.param_norms.append(total_norm_param.item())
def reset_grad_metrics(self):
"""
Reset the gradient metrics.
"""
self.grad_norms = []
self.param_norms = []
self.grad_dict = {i[0]: [] for i in self.named_parameters()}
self.ui_ratings = []
self.ut_ratings = []
        self.ui_attention = []
        self.ut_attention = []
class DMRLLoss(nn.Module):
"""
    The disentangled multi-modal recommendation model loss function. It
    combines a pairwise ranking loss with the disentangled (distance
    correlation) loss. See the DMRL paper for details.
"""
def __init__(self, decay_c, num_factors, num_neg):
super(DMRLLoss, self).__init__()
self.decay_c = decay_c
self.distance_cor_calc = DistanceCorrelationCalculator(
n_factors=num_factors, num_neg=num_neg
)
    def forward(
        self, embedding_factor_lists: EmbeddingFactorLists, rating_scores: torch.Tensor
    ) -> torch.Tensor:
"""
Calculates the loss for the batch of data.
"""
r_pos = rating_scores[:, 0]
        # Among the num_neg sampled negative items, take the one with the
        # largest score, so that each user in the batch contributes a single
        # (hardest) negative sample.
        r_neg = torch.max(rating_scores[:, 1:], dim=1).values
        # ranking loss for the pairwise-based ranking approach
        loss_BPR = torch.sum(torch.nn.Softplus()(-(r_pos - r_neg)))
        # the L2 regularizer is added as weight decay in the optimizer
if self.decay_c > 0:
disentangled_loss = self.distance_cor_calc.calculate_disentangled_loss(
embedding_factor_lists.user_embedding_factors,
embedding_factor_lists.item_embedding_factors,
embedding_factor_lists.text_embedding_factors,
embedding_factor_lists.image_embedding_factors,
)
return loss_BPR + self.decay_c * disentangled_loss
return loss_BPR
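# A toy end-to-end sketch (illustrative only, with made-up sizes): run a
# text-only forward pass and feed the outputs through DMRLLoss. The batch
# layout is [user | 1 positive item | num_neg negative items]; text features
# are assumed to be pre-encoded to width text_dim per item slot.
if __name__ == "__main__":
    model = DMRLModel(
        num_users=10, num_items=20, embedding_dim=8, text_dim=16, image_dim=0,
        dropout=0.0, num_neg=3, num_factors=2,
    )
    model.device = "cpu"  # keep the toy on CPU regardless of CUDA availability
    batch_size, num_neg = 4, 3
    users = torch.randint(0, 10, (batch_size, 1))
    items = torch.randint(0, 20, (batch_size, 1 + num_neg))
    batch = torch.cat([users, items], dim=1)
    text = torch.randn(batch_size, 1 + num_neg, 16)  # stand-in for encoded text
    factor_lists, ratings = model(batch, text, None)
    print(ratings.shape)  # torch.Size([4, 4]) == (batch_size, 1 + num_neg)
    loss_fn = DMRLLoss(decay_c=0.01, num_factors=2, num_neg=num_neg)
    loss = loss_fn(factor_lists, ratings)  # scalar: ranking + disentangled term
    loss.backward()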
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List
import torch.utils.data as data
from cornac.data.dataset import Dataset
import numpy as np
class PWLearningSampler(data.Dataset):
"""
    Sampler for the pairwise-based ranking loss function. It returns a batch
    of positive and negative items: per user index there will be one positive
    and num_neg negative items.
    This sampler is meant to be used in a PyTorch DataLoader, through which
    loading can be distributed amongst multiple workers.
"""
def __init__(self, cornac_dataset: Dataset, num_neg: int):
self.data = cornac_dataset
self.num_neg = num_neg
        # make sure we only have positive ratings, no unseen interactions
assert np.all(self.data.uir_tuple[2] > 0)
self.user_array = self.data.uir_tuple[0]
self.item_array = self.data.uir_tuple[1]
self.unique_items = np.unique(self.item_array)
self.unique_users = np.unique(self.user_array)
self.user_item_array = np.vstack([self.user_array, self.item_array]).T
        # make sure user indices are ascending from 0
        assert np.all(np.unique(self.user_array) == np.arange(self.unique_users.shape[0]))
def __getitems__(self, list_of_indexs: List[int]):
"""
        Vectorized version of __getitem__.
        Uses list_of_indexs to index into the uir_tuple of the cornac dataset
        and thus retrieves one positive item per given user. Additionally,
        randomly samples num_neg negative items per user.
        Parameters
        ----------
        list_of_indexs: List[int]
            List of indices to sample from the uir_tuple of the cornac dataset.
"""
batch_size = len(list_of_indexs)
users = self.user_array[list_of_indexs]
pos_items = self.item_array[list_of_indexs]
pos_u_i = np.vstack([users, pos_items]).T
# sample negative items per user
neg_item_list = []
for _ in range(self.num_neg):
neg_items = np.random.choice(self.data.csr_matrix.shape[1], batch_size)
            # make sure we don't sample a positive item
candidates = self.data.csr_matrix[users, neg_items]
while candidates.nonzero()[0].size != 0:
replacement_neg_items = np.random.choice(self.data.csr_matrix.shape[1], candidates.nonzero()[0].size)
neg_items[candidates.nonzero()[1]] = replacement_neg_items
candidates = self.data.csr_matrix[users, neg_items]
neg_item_list.append(neg_items)
neg_items = np.vstack(neg_item_list).T
return np.hstack([pos_u_i, neg_items])
def __getitem__(self, index):
"""
        Uses index to index into the uir_tuple of the cornac dataset and thus
        retrieves one positive user-item pair. Additionally, randomly samples
        num_neg negative items for that user.
        Parameters
        ----------
        index: int
            Index to sample from the uir_tuple of the cornac dataset.
"""
# first select index tuple
user = self.user_array[index]
item = self.item_array[index]
# now construct positive case
pos_u_i = [user, item]
i = 0
neg_i = []
while i < self.num_neg:
neg_example = np.random.choice(self.data.uir_tuple[1])
idxs_of_item = np.where(self.item_array == neg_example)
users_who_have_rated_item = self.user_array[idxs_of_item]
if user not in users_who_have_rated_item:
i += 1
neg_i = neg_i + [neg_example]
# return user, item_positive, num_neg * item_neg array
return np.array(pos_u_i + neg_i)
def __len__(self):
"""
Return length of sampler as length of uir_tuple from cornac dataset.
"""
return len(self.data.uir_tuple[0])
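# A usage sketch mirroring the unit test further below (illustrative only):
# wrap the sampler in a torch DataLoader; each returned row is
# [user, pos_item, neg_item_1, ..., neg_item_num_neg].
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from cornac.data.reader import Reader
    from cornac.datasets import citeulike

    _, item_ids = citeulike.load_text()
    feedback = citeulike.load_feedback(reader=Reader(item_set=item_ids))
    sampler = PWLearningSampler(Dataset.build(data=feedback), num_neg=4)
    loader = DataLoader(sampler, batch_size=32, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch.shape)  # torch.Size([32, 6]) == (batch_size, 2 + num_neg)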
This diff is collapsed.
pandas
torch
sentence_transformers
pytest
dcor
torchvision
requests
\ No newline at end of file
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from typing import List
from sentence_transformers import SentenceTransformer
from operator import itemgetter
import torch
from cornac.data.modality import FeatureModality
class TransformersTextModality(FeatureModality):
"""
Transformer text modality wrapped around SentenceTransformer library.
https://huggingface.co/sentence-transformers.
Parameters
----------
corpus: List[str], default = None
        List of user/item texts whose indices are aligned with `ids`.
"""
def __init__(
self,
corpus: List[str] = None,
ids: List = None,
preencode: bool = False,
model_name_or_path: str = 'paraphrase-MiniLM-L6-v2',
**kwargs):
super().__init__(ids=ids, **kwargs)
self.corpus = corpus
self.model = SentenceTransformer(model_name_or_path)
self.output_dim = self.model[-1].pooling_output_dimension
self.preencode = preencode
self.preencoded = False
if self.preencode:
self.preencode_entire_corpus()
def preencode_entire_corpus(self):
"""
Pre-encode the entire corpus. This is useful so that we don't have to do
it on the fly in training. Might take significant time to pre-encode
larger datasets.
"""
path = "temp/encoded_corpus.pt"
id_path = "temp/encoded_corpus_ids.pt"
if os.path.exists(path) and os.path.exists(id_path):
saved_ids = torch.load(id_path)
            if saved_ids == self.ids:
                self.features = torch.load(path)
                self.preencoded = True
            else:
                print("The ids of the saved encoded corpus do not match the current ids. Re-encoding the corpus.")
if not self.preencoded:
print("Pre-encoding the entire corpus. This might take a while.")
self.features = self.model.encode(self.corpus, convert_to_tensor=True)
self.preencoded = True
os.makedirs("temp", exist_ok = True)
torch.save(self.features, path)
torch.save(self.ids, id_path)
def batch_encode(self, ids: List[int]):
"""
        Batch-encode the texts for the given list of item ids on the fly.
Parameters
----------
ids: List[int]
List of item ids to encode.
"""
text_batch = list(itemgetter(*ids)(self.corpus))
encoded = self.model.encode(text_batch, convert_to_tensor=True)
return encoded
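# A usage sketch (illustrative only; downloads the sentence-transformer model
# on first use): encode two short item texts and inspect the embedding width.
if __name__ == "__main__":
    modality = TransformersTextModality(
        corpus=["a red cotton shirt", "blue denim jeans"], ids=[0, 1], preencode=False
    )
    emb = modality.batch_encode([0, 1])
    print(emb.shape)  # (2, modality.output_dim); 384 for paraphrase-MiniLM-L6-v2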
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List
from cornac.data.modality import FeatureModality
import os
from PIL.JpegImagePlugin import JpegImageFile
import torch
from torchvision import transforms, models
from torchvision.models._api import WeightsEnum
class TransformersVisionModality(FeatureModality):
"""
Transformer vision modality wrapped around the torchvision ViT Transformer.
Parameters
----------
    images: List[JpegImageFile], default = None
        List of user/item images whose indices are aligned with `ids`.
"""
def __init__(
self,
images: List[JpegImageFile] = None,
ids: List = None,
preencode: bool = False,
model_weights: WeightsEnum = models.ViT_H_14_Weights.DEFAULT,
**kwargs
):
super().__init__(ids=ids, **kwargs)
self.images = images
self.model = models.vit_h_14(weights=model_weights)
# suppress the classification piece
self.model.heads = torch.nn.Identity()
self.model.eval()
self.image_size = (self.model.image_size, self.model.image_size)
self.image_to_tensor_transformer = transforms.Compose(
[
transforms.ToTensor()
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize pixel values
]
)
self.preencode = preencode
self.preencoded = False
self.batch_size = 50
if self.preencode:
self.preencode_images()
def preencode_images(self):
"""
Pre-encode the entire image library. This is useful so that we don't
have to do it on the fly in training. Might take significant time to
pre-encode.
"""
path = "temp/encoded_images.pt"
id_path = "temp/encoded_images_ids.pt"
if os.path.exists(path) and os.path.exists(id_path):
saved_ids = torch.load(id_path)
if saved_ids == self.ids:
self.features = torch.load(path)
self.preencoded = True
else:
assert self.preencoded is False
print(
"The ids of the saved encoded images do not match the current ids. Re-encoding the images."
)
if not self.preencoded:
print("Pre-encoding the entire image library. This might take a while.")
self._encode_images()
self.preencoded = True
os.makedirs("temp", exist_ok=True)
torch.save(self.features, path)
torch.save(self.ids, id_path)
def _encode_images(self):
"""
Encode all images in the library.
"""
        num_batches = (len(self.images) + self.batch_size - 1) // self.batch_size
        for i in range(num_batches):
tensor_batch = self.transform_images_to_torch_tensor(
self.images[i * self.batch_size : (i + 1) * self.batch_size]
)
with torch.no_grad():
encoded_batch = self.model(tensor_batch)
if i == 0:
self.features = encoded_batch
else:
self.features = torch.cat((self.features, encoded_batch), 0)
def transform_images_to_torch_tensor(
self, images: List[JpegImageFile]
) -> torch.Tensor:
"""
        Transforms a list of PIL images into a torch tensor batch.
Parameters
----------
images: List[PIL.Image]
List of PIL images to be transformed to torch tensor.
"""
for i, img in enumerate(images):
if img.size != self.image_size:
img = img.resize(self.image_size)
tensor = self.image_to_tensor_transformer(img)
tensor = tensor.unsqueeze(0)
if i == 0:
tensor_batch = tensor
else:
tensor_batch = torch.cat((tensor_batch, tensor), 0)
return tensor_batch
def batch_encode(self, ids: List[int]):
"""
        Batch-encode the images for the given list of item ids on the fly.
Parameters
----------
ids: List[int]
List of item ids to encode.
"""
        images = [self.images[i] for i in ids]
        tensor_batch = self.transform_images_to_torch_tensor(images)
with torch.no_grad():
encoded_batch = self.model(tensor_batch)
return encoded_batch
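# A usage sketch (illustrative only): encode two images with the ViT-H/14
# backbone. The DEFAULT weights are a multi-GB download, and the image paths
# below are hypothetical placeholders.
if __name__ == "__main__":
    from PIL import Image

    images = [Image.open(p).convert("RGB") for p in ["item_0.jpg", "item_1.jpg"]]
    modality = TransformersVisionModality(images=images, ids=[0, 1], preencode=False)
    emb = modality.batch_encode([0, 1])
    print(emb.shape)  # torch.Size([2, 1280]), the ViT-H/14 feature width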
...@@ -11,6 +11,11 @@ Recommender (Generic Class)
.. automodule:: cornac.models.recommender
   :members:
Disentangled Multimodal Representation Learning for Recommendation (DMRL)
--------------------------------------------------------------------------
.. automodule:: cornac.models.dmrl.recom_dmrl
   :members:
Bilateral VAE for Collaborative Filtering (BiVAECF)
---------------------------------------------------
.. automodule:: cornac.models.bivaecf.recom_bivaecf
   :members:
...
...@@ -54,6 +54,8 @@
[cvae_example.py](cvae_example.py) - Collaborative Variational Autoencoder (CVAE) with CiteULike dataset.
[dmrl_example.py](dmrl_example.py) - Disentangled Multimodal Representation Learning (DMRL) with CiteULike dataset.
[trirank_example.py](trirank_example.py) - TriRank with Amazon Toy and Games dataset.
[efm_example.py](efm_example.py) - Explicit Factor Model (EFM) with Amazon Toy and Games dataset.
...@@ -68,6 +70,8 @@
[causalrec_clothing.py](causalrec_clothing.py) - CausalRec with Clothing dataset.
[dmrl_clothes_example.py](dmrl_clothes_example.py) - Disentangled Multimodal Representation Learning (DMRL) with Amazon Clothing dataset.
[vbpr_tradesy.py](vbpr_tradesy.py) - Visual Bayesian Personalized Ranking (VBPR) with Tradesy dataset.
[vmf_clothing.py](vmf_clothing.py) - Visual Matrix Factorization (VMF) with Amazon Clothing dataset.
...
"""
Example for Disentangled Multimodal Recommendation, with feedback, textual and
visual modalities. This example uses pre-encoded visual features from the cornac
dataset instead of the TransformersVisionModality.
"""
import cornac
from cornac.data import TextModality, ImageModality
from cornac.datasets import amazon_clothing
from cornac.eval_methods import RatioSplit
feedback = amazon_clothing.load_feedback()
image_features, image_item_ids = amazon_clothing.load_visual_feature() # BIG file
docs, text_item_ids = amazon_clothing.load_text()
# only treat good feedback as positive user-item pair
new_feedback = [f for f in feedback if f[2] >= 4]
text_modality = TextModality(corpus=docs, ids=text_item_ids)
image_modality = ImageModality(features=image_features, ids=image_item_ids)
ratio_split = RatioSplit(
data=new_feedback,
test_size=0.25,
exclude_unknowns=True,
verbose=True,
seed=123,
rating_threshold=4,
item_text=text_modality,
item_image=image_modality,
)
dmrl_recommender = cornac.models.dmrl.DMRL(
batch_size=1024,
epochs=60,
log_metrics=False,
learning_rate=0.001,
num_factors=2,
decay_r=2,
decay_c=0.1,
num_neg=5,
embedding_dim=100,
image_dim=4096,
dropout=0,
)
# Use Recall@300 for evaluations
rec_300 = cornac.metrics.Recall(k=300)
rec_900 = cornac.metrics.Recall(k=900)
prec_30 = cornac.metrics.Precision(k=30)
# Put everything together into an experiment and run it
cornac.Experiment(
eval_method=ratio_split,
models=[dmrl_recommender],
metrics=[prec_30, rec_300, rec_900],
).run()
"""Example for Disentangled Multimodal Recommendation, with only feedback and textual modality.
For an example including image modality please see dmrl_clothes_example.py"""
import cornac
from cornac.data import Reader
from cornac.datasets import citeulike
from cornac.eval_methods import RatioSplit
from cornac.data import TextModality
# The necessary data can be loaded as follows
docs, item_ids = citeulike.load_text()
feedback = citeulike.load_feedback(reader=Reader(item_set=item_ids))
item_text_modality = TextModality(
corpus=docs,
ids=item_ids,
)
# Define an evaluation method to split feedback into train and test sets
ratio_split = RatioSplit(
data=feedback,
test_size=0.2,
exclude_unknowns=True,
verbose=True,
seed=123,
rating_threshold=0.5,
item_text=item_text_modality,
)
# Instantiate DMRL recommender
dmrl_recommender = cornac.models.dmrl.DMRL(
batch_size=4096,
epochs=20,
log_metrics=False,
learning_rate=0.01,
num_factors=2,
decay_r=0.5,
decay_c=0.01,
num_neg=3,
embedding_dim=100,
)
# Use Recall@300 for evaluations
rec_300 = cornac.metrics.Recall(k=300)
prec_30 = cornac.metrics.Precision(k=30)
# Put everything together into an experiment and run it
cornac.Experiment(
eval_method=ratio_split, models=[dmrl_recommender], metrics=[prec_30, rec_300]
).run()
# Configuration of py.test
[pytest]
norecursedirs = tests/cornac/datasets
pythonpath = .
addopts=-v
    --durations=20
    --ignore=tests/cornac/utils/test_download.py
...
"""
Pytest tests for cornac.models.dmrl.d_cor_calc.py
"""
try:
import torch
import dcor
from cornac.models.dmrl.d_cor_calc import DistanceCorrelationCalculator
run_dmrl_test_funcs = True
except ImportError:
run_dmrl_test_funcs = False
def skip_test_in_case_of_missing_reqs(test_func):
test_func.__test__ = (
run_dmrl_test_funcs # Mark the test function as (non-)discoverable by unittest
)
return test_func
# first test the distance correlation calculator
@skip_test_in_case_of_missing_reqs
def test_distance_correlation_calculator():
"""
    Test the distance correlation calculator. Compare against the dcor library.
"""
    num_neg = 4 + 1  # the test uses 5 sample slots (1 positive + 4 negatives)
distance_cor_calc = DistanceCorrelationCalculator(n_factors=2, num_neg=num_neg)
# create some fake data
tensor_x = torch.randn(3, num_neg, 10)
tensor_y = torch.randn(3, num_neg, 10)
assert tensor_x.shape[1] == num_neg
assert tensor_y.shape[1] == num_neg
cor_per_sample = distance_cor_calc.calculate_cor(tensor_x, tensor_y)
assert cor_per_sample.shape[0] == tensor_x.shape[1]
for sample in range(num_neg - 1):
        # compare to two decimal places
assert round(cor_per_sample[sample].item(), 2) == round(
dcor.distance_correlation(
tensor_x[:, sample, :], tensor_y[:, sample, :]
).item(),
2,
)
# A checker to make sure all requirements needed by the imports here are really
# present; if any are missing, skip the respective tests.
# To run these tests, please run: pip install -r cornac/models/dmrl/requirements.txt
import unittest
from cornac.data.dataset import Dataset
from cornac.data.reader import Reader
from cornac.datasets import citeulike
try:
from torch.utils.data import DataLoader
from cornac.models.dmrl.pwlearning_sampler import PWLearningSampler
run_dmrl_test_funcs = True
except ImportError:
run_dmrl_test_funcs = False
def skip_test_in_case_of_missing_reqs(test_func):
test_func.__test__ = (
run_dmrl_test_funcs # Mark the test function as (non-)discoverable by unittest
)
return test_func
class TestPWLearningSampler(unittest.TestCase):
"""
Tests that the PW Sampler returns the desired batch of data.
"""
@skip_test_in_case_of_missing_reqs
def setUp(self):
self.num_neg = 4
_, item_ids = citeulike.load_text()
feedback = citeulike.load_feedback(reader=Reader(item_set=item_ids))
cornac_dataset = Dataset.build(data=feedback)
self.sampler = PWLearningSampler(cornac_dataset, num_neg=self.num_neg)
@skip_test_in_case_of_missing_reqs
def test_get_batch_multiprocessed(self):
"""
        Tests multiprocess loading via the torch DataLoader.
"""
batch_size = 32
dataloader = DataLoader(
self.sampler,
batch_size=batch_size,
num_workers=3,
shuffle=True,
prefetch_factor=3,
)
generator_data_loader = iter(dataloader)
batch = next(generator_data_loader)
assert batch.shape == (batch_size, 2 + self.num_neg)
@skip_test_in_case_of_missing_reqs
def test_correctness(self):
"""
Tests the correctness of the PWLearningSampler by asserting that the
correct positive and negative user-item pairs are returned.
"""
batch_size = 32
dataloader = DataLoader(
self.sampler,
batch_size=batch_size,
num_workers=0,
shuffle=True,
prefetch_factor=None,
)
generator_data_loader = iter(dataloader)
batch = next(generator_data_loader).numpy()
assert batch.shape == (batch_size, 2 + self.num_neg)
for i in range(batch_size):
user = batch[i, 0]
pos_item = batch[i, 1]
neg_items = batch[i, 2:]
assert pos_item in self.sampler.data.csr_matrix[user].nonzero()[1]
for neg_item in neg_items:
assert neg_item not in self.sampler.data.csr_matrix[user].nonzero()[1]
@skip_test_in_case_of_missing_reqs
def test_full_epoch_sampler(self):
"""
Tests speed of loader for full epoch
"""
batch_size = 32
dataloader = DataLoader(
self.sampler,
batch_size=batch_size,
num_workers=0,
shuffle=True,
prefetch_factor=None,
)
i = 0
for _ in dataloader:
i += 1
assert i == self.sampler.user_array.shape[0] // batch_size + 1
# A checker to make sure all requirements needed by the imports here are really
# present; if any are missing, skip the respective tests.
# To run these tests, please run: pip install -r cornac/models/dmrl/requirements.txt
import unittest
try:
import torch
from sentence_transformers import util
from cornac.models.dmrl.transformer_text import TransformersTextModality
run_dmrl_test_funcs = True
except ImportError:
run_dmrl_test_funcs = False
def skip_test_in_case_of_missing_reqs(test_func):
test_func.__test__ = (
run_dmrl_test_funcs # Mark the test function as (non-)discoverable by unittest
)
return test_func
class TestTransformersTextModality(unittest.TestCase):
@skip_test_in_case_of_missing_reqs
def setUp(self):
self.corpus = ["I like you very much.", "I like you so much"]
self.ids = [0, 1]
self.modality = TransformersTextModality(
corpus=self.corpus, ids=self.ids, preencode=True
)
@skip_test_in_case_of_missing_reqs
def test_batch_encode(self):
encoded_batch = self.modality.batch_encode(self.ids)
assert encoded_batch.shape[0] == 2
assert isinstance(encoded_batch, torch.Tensor)
@skip_test_in_case_of_missing_reqs
def test_preencode_entire_corpus(self):
self.modality.preencode_entire_corpus()
assert self.modality.preencoded
assert torch.load("temp/encoded_corpus_ids.pt") == self.ids
assert torch.load("temp/encoded_corpus.pt").shape[0] == len(self.corpus)
@skip_test_in_case_of_missing_reqs
def test_batch_encode_similarity(self):
encoded_batch = self.modality.batch_encode(self.ids)
similarity = util.cos_sim(encoded_batch[0], encoded_batch[1])
assert similarity > 0.9