diff --git a/.gitignore b/.gitignore index 01271755282894ddfe89bd70350a4da9587fbea7..a936ed943afff0c223478891fa6f81c6e7ea408b 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,11 @@ coverage.xml # Sphinx documentation docs/_build/ +# Dataset stuff +**.pickle +**.csv +**.lock +**.zip # Environment env diff --git a/README.md b/README.md index 58779f5dcb2ebae1e3d54489fd870a3e70ceed58..93406bd99933953ad2dd025751f89d17a4e331f5 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | Year | Model and paper | Model type | Require-ments | Examples | | :---: | --- | :---: | :---: | :---: | +| 2024 | [Hypergraphs with Attention on Reviews (HypAR)](cornac/models/hypar), [paper](https://doi.org/10.1007/978-3-031-56027-9_14)| Hybrid / Sentiment / Explainable | [reqs](cornac/models/hypar/requirements_cu116.txt) | [exp](https://github.com/PreferredAI/HypAR) | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | Collaborative Filtering / Content-Based | [reqs](cornac/models/bivaecf/requirements.txt) | [exp](https://github.com/PreferredAI/bi-vae) | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | Content-Based / Image | [reqs](cornac/models/causalrec/requirements.txt) | [exp](examples/causalrec_clothing.py) | | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | Explainable | N/A | [exp](https://github.com/PreferredAI/ComparER) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 8c1b2b7ea630b9b8445d7f94790f5bfda81bd964..f4256b81b89c77ff67531a67a2aefd0a08288fe8 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ 
-49,6 +49,7 @@ from .gru4rec import GRU4Rec from .hft import HFT from .hpf import HPF from .hrdr import HRDR +from .hypar import HypAR from .ibpr import IBPR from .knn import ItemKNN from .knn import UserKNN diff --git a/cornac/models/hypar/__init__.py b/cornac/models/hypar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e29a12743d9e0a2b35b6e0b800666318e75d696a --- /dev/null +++ b/cornac/models/hypar/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from .recom_hypar import HypAR \ No newline at end of file diff --git a/cornac/models/hypar/dgl_utils.py b/cornac/models/hypar/dgl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d02e790a601b20f7b35e9b578263b08997d448a --- /dev/null +++ b/cornac/models/hypar/dgl_utils.py @@ -0,0 +1,431 @@ +import re +from collections import OrderedDict, Counter, defaultdict +from typing import Mapping +from functools import lru_cache + +import dgl.dataloading +import torch +from dgl.dataloading.negative_sampler import _BaseNegativeSampler +import dgl.backend as F + +class HypAREdgeSampler(dgl.dataloading.EdgePredictionSampler): + def __init__(self, sampler, exclude=None, reverse_eids=None, + reverse_etypes=None, negative_sampler=None, prefetch_labels=None): + super().__init__(sampler, exclude, reverse_eids, reverse_etypes, negative_sampler, + prefetch_labels) + + def sample(self, g, seed_edges): # pylint: disable=arguments-differ + """Samples a list of blocks, as well as a subgraph containing the sampled + edges from the original graph. + + If :attr:`negative_sampler` is given, also returns another graph containing the + negative pairs as edges. 
class HypARBlockSampler(dgl.dataloading.NeighborSampler):
    """
    Block sampler that, for a set of seed nodes, samples reviews from the node-review graph,
    samples user/item neighbourhoods for the preference module, and draws positive/negative
    aspect-opinion-sentiment (AOS) indices for the seed reviews.

    Parameters
    ----------
    node_review_graph: DGLHeteroGraph
        A heterogeneous graph with edges from reviews to nodes (users/items) with relation-type part_of.
    review_graphs: dict[DGLGraph]
        A dictionary with sid to a graph representing a review based on sentiment.
    aggregator: str
        Stored on the sampler; not read by the sampler itself.
    sid_aos: list[torch.Tensor]
        For each review id (sid), the indices (into ``aos_list``) of the AOS entries of that review.
    aos_list: list
        AOS entries; converted to a LongTensor and indexed by the sampled positive/negative indices.
    n_neg: int
        Number of negative AOS indices drawn per seed review.
    ui_graph: DGLHeteroGraph
        User-item graph with node types 'user' and 'item' and edge types 'user_item'/'item_user'.
    compact: bool
        Stored on the sampler; not read by the sampler itself.
    fanout: int
        Number of reviews to sample per node.
    kwargs: dict
        Arguments to pass to NeighborSampler. Read DGL docs for options.
    """

    def __init__(self, node_review_graph, review_graphs, aggregator, sid_aos, aos_list, n_neg, ui_graph,
                 compact=True, fanout=5, **kwargs):
        # A single hop is sampled on the node-review graph.
        super().__init__([fanout], **kwargs)
        self.node_review_graph = node_review_graph
        self.review_graphs = review_graphs
        self.aggregator = aggregator
        self.sid_aos = sid_aos
        self.aos_list = torch.LongTensor(aos_list)

        # log(count + 1) for every AOS index, ordered by index.
        # NOTE(review): sample() only uses the *length* of this tensor; negatives are drawn uniformly.
        ac = Counter(a for aos in sid_aos for a in aos.numpy())
        self.aos_probabilities = torch.log(torch.FloatTensor([ac[a] for a in sorted(ac)]) + 1)

        self.n_neg = n_neg
        self.ui_graph = ui_graph
        self.compact = compact
        self.n_ui_graph = self._nu_graph()
        self.exclude_sids = self._create_exclude_sids(self.node_review_graph)

    def _create_exclude_sids(self, n_r_graph):
        """
        For every review, collect the reviews attached to any node that review is attached to.

        Parameters
        ----------
        n_r_graph: node_review graph

        Returns
        -------
        list
            One LongTensor of review ids per review (empty if the review has no successors),
            ordered by sorted review id.
        """
        exclude_sids = []
        for sid in sorted(n_r_graph.nodes('review')):
            # All reviews pointing at any node this review points at (includes the review itself).
            es = [n_r_graph.predecessors(n) for n in n_r_graph.successors(sid)]
            exclude_sids.append(torch.cat(es) if len(es) > 0 else torch.LongTensor([]))
        return exclude_sids

    def _nu_graph(self):
        """
        Create graph mapping user/items to node ids. Used for preference where users and items are separated,
        while for reviews they are combined or just seen as a node.

        Returns
        -------
        DGLHeteroGraph
        """
        # Get number of user, item and nodes
        n_nodes = self.node_review_graph.num_nodes('node')
        n_users = self.ui_graph.num_nodes('user')
        n_items = self.ui_graph.num_nodes('item')

        # Get all nodes. Node ids below n_items are treated as items, the rest as users.
        nodes = self.node_review_graph.nodes('node')
        device = nodes.device
        nodes = nodes.cpu()

        # Create mapping
        data = {
            ('user', 'un', 'node'): (torch.arange(n_users, dtype=torch.int64), nodes[nodes >= n_items]),
            ('item', 'in', 'node'): (torch.arange(n_items, dtype=torch.int64), nodes[nodes < n_items])
        }

        return dgl.heterograph(data, num_nodes_dict={'user': n_users, 'item': n_items, 'node': n_nodes}).to(device)

    def sample(self, g, seed_nodes, exclude_eids=None, seed_edges=None):
        """
        Sample review blocks, user/item blocks and AOS positives/negatives for ``seed_nodes``.

        Parameters
        ----------
        g: DGLGraph
            Graph holding the seed edges; must carry the review id of each edge in ``edata['sid']``.
        seed_nodes: torch.Tensor
            Node ids to sample reviews for.
        exclude_eids: torch.Tensor, optional
            Edge ids of ``g`` whose reviews must be hidden during sampling. When ``None`` nothing is
            excluded, the review mask is all ones and no AOS samples are produced.
        seed_edges: torch.Tensor, optional
            Seed edge ids of type 'user_item', used to exclude them (and their reverse) in the
            user-item graph.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, [pos_aos, neg_aos], [blocks, blocks2, mask])``.
        """
        nrg_exclude_eids = None
        lgcn_exclude_eids = None
        seed_sids = None

        # One entry per review: 1 keeps the review, 0 hides it. Always defined so the return
        # statement is valid even when nothing is excluded.
        mask = torch.ones((len(self.sid_aos)))

        # If exclude ids, find the equivalent edges in the other graphs.
        if exclude_eids is not None:
            # Find sid of the exclude eids.
            u, _ = g.find_edges(exclude_eids)
            seed_sids = g.edata['sid'][exclude_eids]
            sid = seed_sids.to(u.device)

            # Find exclude eids based on sid and source nodes in g.
            nrg_exclude_eids = self.node_review_graph.edge_ids(sid, u, etype='part_of')

            # Exclude the seed edges and their reverse counterparts in the user-item graph.
            lgcn_exclude_eids = dgl.dataloading.find_exclude_eids(
                self.ui_graph, {'user_item': seed_edges}, 'reverse_types', None,
                {'user_item': 'item_user', 'item_user': 'user_item'}, self.output_device)
            mask[sid] = 0

        # Based on seed_nodes, find reviews to represent the nodes.
        input_nodes, output_nodes, blocks = super().sample(self.node_review_graph, {'node': seed_nodes},
                                                           nrg_exclude_eids)
        block = blocks[0]['part_of']
        blocks[0] = block

        # If all reviews of a node were removed, attach random reviews to it.
        # Will not occur during inference.
        isolated = block.in_degrees(block.dstnodes()) == 0
        if torch.any(isolated):
            for index in torch.where(isolated)[0]:
                perm = torch.randperm(block.num_src_nodes())
                block.add_edges(block.srcnodes()[perm[:self.fanouts[0]]],
                                index.repeat(min(max(1, self.fanouts[0]), block.num_src_nodes())))

        blocks2 = []
        seed_nodes = output_nodes

        # LightGCN Sampling: one hop from nodes to users/items, then three hops on the user-item graph.
        for i in range(4):
            if i == 0:
                # Use node to user/item graph to sample first.
                frontier = self.n_ui_graph.sample_neighbors(
                    seed_nodes, -1, edge_dir=self.edge_dir, prob=self.prob,
                    replace=self.replace, output_device=self.output_device,
                    exclude_edges=None)
            else:
                frontier = self.ui_graph.sample_neighbors(
                    seed_nodes, -1, edge_dir=self.edge_dir, prob=self.prob,
                    replace=self.replace, output_device=self.output_device,
                    exclude_edges=lgcn_exclude_eids)

            # Turn the frontier into a block and keep the original edge ids.
            eid = frontier.edata[dgl.EID]
            block = dgl.to_block(frontier, seed_nodes)
            block.edata[dgl.EID] = eid
            seed_nodes = block.srcdata[dgl.NID]
            blocks2.insert(0, block)

        pos_aos = []
        neg_aos = []
        # Find aspect/opinion sentiment based on the seed reviews (none when nothing is excluded).
        seed_sid_values = [] if seed_sids is None else seed_sids.cpu().numpy()
        for sid in seed_sid_values:
            aosid = self.sid_aos[sid]
            aosid = aosid[torch.randperm(len(aosid))[0]]
            pos_aos.append(aosid)  # Add positive sample.

            probability = torch.ones(len(self.aos_probabilities))

            # Exclude self and other aspects/opinions mentioned by the user or item.
            probability[aosid] = 0
            exclude_sids = torch.cat([self.sid_aos[i] for i in self.exclude_sids[sid]])
            probability[exclude_sids] = 0

            # Add negative samples based on probability (allow duplicates).
            neg_aos.append(torch.multinomial(probability, self.n_neg, replacement=True))

        # Transform to tensors; torch.stack cannot handle an empty list, so build empty indices explicitly.
        if pos_aos:
            pos_aos = torch.LongTensor(pos_aos)
            neg_aos = torch.stack(neg_aos)
        else:
            pos_aos = torch.empty(0, dtype=torch.long)
            neg_aos = torch.empty((0, self.n_neg), dtype=torch.long)

        # Based on sampled index, get actual aos.
        pos_aos, neg_aos = self.aos_list[pos_aos], self.aos_list[neg_aos]

        return input_nodes, output_nodes, [pos_aos, neg_aos], [blocks, blocks2, mask]
@lru_cache()
def generate_mappings(sentiment, match, get_ao_mappings=False, get_sent_edge_mappings=False):
    """
    Build lookup tables between users, items, reviews and (stemmed) sentiment elements.

    Parameters
    ----------
    sentiment : SentimentModality
        Modality providing ``user_sentiment`` (uid -> {iid -> sid}); it is also passed to ``stem``.
    match : str
        Which part of an (aspect, opinion, sentiment) triple forms an element:
        'aos', 'a', 'as' or 'ao'. Any other value raises NotImplementedError once a triple is visited.
    get_ao_mappings : bool
        If True, append the old-to-new aspect and opinion id mappings returned by ``stem``.
    get_sent_edge_mappings : bool
        If True, append the (sid, uid) -> edge id and (sid, iid) -> edge id mappings.

    Returns
    -------
    tuple
        (aos_user, aos_item, aos_sent, user_aos, item_aos, sent_aos), followed by the optional
        mappings requested above. Results are memoised by ``lru_cache`` on the arguments.
    """
    # How to turn a triple into the element used as key/value in the tables below.
    element_builders = {
        'aos': lambda a, o, s: (a, o, s),
        'a': lambda a, o, s: a,
        'as': lambda a, o, s: (a, s),
        'ao': lambda a, o, s: (a, o),
    }

    # element -> users/items/reviews and users/items/reviews -> elements
    aos_user, aos_item, aos_sent = defaultdict(list), defaultdict(list), defaultdict(list)
    user_aos, item_aos, sent_aos = defaultdict(list), defaultdict(list), defaultdict(list)
    user_sent_edge_map = {}
    item_sent_edge_map = {}

    # Get new sentiments and mappings from old to new id.
    sent, a_mapping, o_mapping = stem(sentiment)

    # Every (user, item, review) gets two consecutive edge ids: first the user edge, then the item edge.
    next_edge_id = 0
    for uid, item_to_sid in sentiment.user_sentiment.items():
        for iid, sid in item_to_sid.items():
            user_sent_edge_map[(sid, uid)] = next_edge_id
            item_sent_edge_map[(sid, iid)] = next_edge_id + 1
            next_edge_id += 2

            for a, o, s in sent[sid]:
                if match not in element_builders:
                    raise NotImplementedError
                element = element_builders[match](a, o, s)

                aos_user[element].append(uid)
                aos_item[element].append(iid)
                aos_sent[element].append(sid)
                user_aos[uid].append(element)
                item_aos[iid].append(element)
                sent_aos[sid].append(element)

    result = (aos_user, aos_item, aos_sent, user_aos, item_aos, sent_aos)
    if get_ao_mappings:
        result += (a_mapping, o_mapping)
    if get_sent_edge_mappings:
        result += (user_sent_edge_map, item_sent_edge_map)

    return result
+ """ + import dgl + import torch + + edge_types = { + 'mentions': [], + 'described_as': [], + 'has_opinion': [], + 'co-occur': [], + } + + rating_types = set() + for indices in list(zip(*train_set.matrix.nonzero())): + rating_types.add(train_set.matrix[indices]) + + if not bipartite: + train_types = [] + for rt in rating_types: + edge_types[str(rt)] = [] + train_types.append(str(rt)) + else: + train_types = ['interacted'] + edge_types['interacted'] = [] + + sentiment_modality = train_set.sentiment + n_users = len(train_set.uid_map) + n_items = len(train_set.iid_map) + n_aspects = len(sentiment_modality.aspect_id_map) + n_opinions = len(sentiment_modality.opinion_id_map) + n_nodes = n_users + n_items + n_aspects + n_opinions + + # Create all the edges: (item, described_as, aspect), (item, has_opinion, opinion), (user, mentions, aspect), + # (aspect, cooccur, opinion), and (user, 'rating', item). Note rating is on a scale. + for org_uid, isid in sentiment_modality.user_sentiment.items(): + uid = org_uid + n_items + for iid, sid in isid.items(): + for aid, oid, _ in sentiment_modality.sentiment[sid]: + aid += n_items + n_users + oid += n_items + n_users + n_aspects + + edge_types['mentions'].append([uid, aid]) + edge_types['mentions'].append([uid, oid]) + edge_types['described_as'].append([iid, aid]) + edge_types['described_as'].append([iid, oid]) + edge_types['co-occur'].append([aid, oid]) + + if not bipartite: + edge_types[str(train_set.matrix[(org_uid, iid)])].append([uid, iid]) + else: + edge_types['interacted'].append([uid, iid]) + + # Create reverse edges. 
class AOSPredictionLayer(nn.Module):
    """
    Ranking layer for AOS prediction.

    Parameters
    ----------
    aos_predictor : str
        Type of AOS predictor. Can be 'non-linear' or 'transr'.
    in_dim1: int
        Dimension of a single aspect/opinion embedding (the aspect and opinion are concatenated).
    in_dim2: int
        Dimension of a single user/item embedding (the user and item are concatenated).
    hidden_dims: list[int]
        List of hidden dimensions, for multiple MLP layers. The last entry is the output dimension.
    n_relations: int
        Number of relations, i.e. sentiments.
    loss: str
        Loss function to be used. Can be 'bpr' or 'transr'.
    """

    def __init__(self, aos_predictor, in_dim1, in_dim2, hidden_dims, n_relations, loss='bpr'):
        # Initialize variables
        super().__init__()
        self.loss = loss
        assert loss in ['bpr', 'transr'], f'Invalid loss: {loss}'

        # Fail before any parameter is built if the predictor is unknown.
        if aos_predictor not in ('non-linear', 'transr'):
            raise NotImplementedError(aos_predictor)

        hidden_dims = list(hidden_dims)
        ao_dims = [in_dim1 * 2] + hidden_dims
        ui_dims = [in_dim2 * 2] + hidden_dims
        r_dim = hidden_dims[-1]

        # Either have nonlinear mlp transformation or use transr like similarity
        if aos_predictor == 'non-linear':
            # One aspect/opinion MLP per sentiment; a single MLP for the user/item pair.
            self.mlp_ao = nn.ModuleList(
                self._mlp(ao_dims) for _ in range(n_relations)
            )
            self.mlp_ui = self._mlp(ui_dims)
            # Registered for parity with 'transr'; forward() does not read it in this mode.
            self.r = nn.Parameter(torch.zeros((n_relations, r_dim)))
        else:
            self.w_aor = nn.Parameter(torch.zeros((n_relations, in_dim1 * 2, r_dim)))
            self.w_uir = nn.Parameter(torch.zeros((n_relations, in_dim2 * 2, r_dim)))
            self.r = nn.Parameter(torch.zeros((n_relations, r_dim)))
            for parameter in (self.w_aor, self.w_uir, self.r):
                nn.init.xavier_normal_(parameter)

        self._aos_predictor = aos_predictor
        self._n_relations = n_relations
        self._out_dim = r_dim

    @staticmethod
    def _mlp(dims):
        """Stack of Linear + LeakyReLU layers mapping dims[0] -> dims[-1]."""
        return nn.Sequential(
            *[nn.Sequential(nn.Linear(d_in, d_out), nn.LeakyReLU()) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, u_emb, i_emb, a_emb, o_emb, s):
        """
        Calculates the AOS prediction
        Parameters
        ----------
        u_emb: torch.Tensor
            User embedding, one row per sample.
        i_emb: torch.Tensor
            Item embedding, one row per sample.
        a_emb: torch.Tensor
            Aspect embedding; either one per sample (2-D) or several candidates per sample (3-D).
        o_emb: torch.Tensor
            Opinion embedding, same layout as ``a_emb``.
        s: torch.Tensor
            Sentiment label per aspect/opinion pair, with values in ``range(n_relations)``.
            Pairs with any other label get an all-zero representation.

        Returns
        -------
        torch.Tensor
            Score of ui/aos ranking, one score per sample and candidate.
        """

        # Concatenate user and item embeddings
        ui_in = torch.cat([u_emb, i_emb], dim=-1)
        ao_in = torch.cat([a_emb, o_emb], dim=-1)

        # Get size; a 2-D input is treated as a single candidate per sample.
        if ao_in.dim() == 3:
            b, n, d = ao_in.size()
        else:
            b, d = ao_in.size()
            n = 1

        # Reshape
        s = s.reshape(b, n)
        ao_in = ao_in.reshape(b, n, d)

        # Transform using either non-linear mlp or transr.
        # Buffers are zero-filled (not uninitialised) so that entries whose label matches no relation
        # yield a defined score instead of arbitrary memory; they follow the input dtype and device.
        if self._aos_predictor == 'non-linear':
            ui_emb = self.mlp_ui(ui_in)
            aos_emb = ui_emb.new_zeros((b, n, self._out_dim))
            for r in range(self._n_relations):
                mask = s == r
                aos_emb[mask] = self.mlp_ao[r](ao_in[mask])
            ui_emb = ui_emb.unsqueeze(1)
        elif self._aos_predictor == 'transr':
            ui_emb = ui_in.new_zeros((b, n, self._out_dim))
            aos_emb = ui_in.new_zeros((b, n, self._out_dim))
            for r in range(self._n_relations):
                mask = s == r
                # Repeat each user/item row once per candidate of this relation so rows line up with mask.
                ui_emb[mask] = torch.repeat_interleave(ui_in, mask.sum(-1), dim=0) @ self.w_uir[r] + self.r[r]
                aos_emb[mask] = ao_in[mask] @ self.w_aor[r]
        else:
            raise NotImplementedError(self._aos_predictor)

        if self.loss == 'bpr':
            # Inner product similarity.
            pred = (ui_emb * aos_emb).sum(-1)
        else:
            # Squared euclidean distance.
            pred = (ui_emb - aos_emb).pow(2).sum(-1)

        return pred
+ """ + + def __init__(self, H, in_dim, non_linear=True, num_layers=1, dropout=0, aggregator='mean', + normalize=False): + super().__init__() + self.aggregator = aggregator + self.non_linear = non_linear + self.normalize = normalize + + # Initialize matrices + self.H = None + self.D_e_inv = None + self.L_left = None + self.L_right = None + self.L = None + self.O = None + self.D_v_invsqrt = None + self.heads = None + self.tails = None + self.edges = None + self.uniques = None + + # Set matrices + self.set_matrices(H) + + # Define layers + self.num_layers = num_layers + self.in_dim = in_dim + self.W = nn.ModuleList([ + nn.ModuleDict({ + k: nn.Linear(in_dim, in_dim) for k in H + }) for _ in range(num_layers) + ]) + + # Set dropout and activation + self.dropout = nn.Dropout(dropout) + self.activation = nn.LeakyReLU() + + def set_matrices(self, H): + """ + Initialize matrices for hypergraph layer for faster computation in forward and backward pass. + Parameters + ---------- + H: dict + Hypergraph incidence matrix for each relation relation type. I.e., positive and negative AO pairs. + + Returns + ------- + None + """ + + # Set hypergraph + self.H = H + + # Compute degree matrices, node and edge-wise + d_V = {k: v.sum(1) for k, v in H.items()} + d_E = {k: v.sum(0) for k, v in H.items()} + + self.D_v_invsqrt = {k: dglsp.diag(v ** -.5) for k, v in d_V.items()} + self.D_e_inv = {k: dglsp.diag(v ** -1) for k, v in d_E.items()} + + # Compute Laplacian from the equation above. 
+ self.L_left = {k: self.D_v_invsqrt[k] @ H[k] for k in H} + self.L_right = {k: H[k].T @ self.D_v_invsqrt[k] for k in H} + self.L = {k: self.L_left[k] @ self.D_e_inv[k] @ self.L_right[k] for k in H} + + # Out representation + self.O = {k: self.D_e_inv[k] @ H[k].T for k in H} + + def unset_matrices(self): + self.H = None + self.D_e_inv = None + self.L_left = None + self.L_right = None + self.L = None + self.O = None + self.D_v_invsqrt = None + + def forward(self, x, mask=None): + D_e = self.D_e_inv + + # Mask if in train + if mask is not None: + # Compute laplacian matrix + D_e = {k: dglsp.diag(D_e[k].val * mask) for k in D_e} + L = {k: self.L_left[k] @ D_e[k] @ self.L_right[k] for k in D_e} + else: + L = self.L + + node_out = [x] + review_out = [] + # Iterate over layers + for i, layer in enumerate(self.W): + + # Initialize in and out layers + inner_x = [] + inner_o = [] + + # Iterate over relation types (i.e., positive and negative AO pairs) + # k is type and l linear layer. + for k, l in layer.items(): + # Compute next layer + e = L[k] @ l(self.dropout(x)) + + # Apply non-linear activation + if self.non_linear: + e = self.activation(e) + + # Get node representation + o = self.O[k] @ e # average of nodes participating in review edge + + inner_x.append(e) + inner_o.append(o) + + # Combine sentiments + x = torch.stack(inner_x) + inner_o = torch.stack(inner_o) + + # Aggregate over sentiments + if self.aggregator == 'sum': + x = x.sum(0) + inner_o = inner_o.sum(0) + elif self.aggregator == 'mean': + x = x.mean(0) + inner_o = inner_o.mean(0) + else: + raise NotImplementedError(self.aggregator) + + # If using layer normalization, normalize using l2 norm. + if self.normalize: + x = x / (x.norm(2, dim=-1, keepdim=True) + 1e-5) # add epsilon to avoid division by zero. + inner_o = inner_o / (inner_o.norm(2, dim=-1, keepdim=True) + 1e-5) + + # Append representations + node_out.append(x) + review_out.append(inner_o) + + # Return aggregated representation using mean. 
class ReviewConv(nn.Module):
    """
    Review attention aggregation layer
    Parameters
    ----------
    aggregator: str
        Aggregator to use. Can be 'gatv2' and 'narre'.
    n_nodes: int
        Number of nodes
    in_feats: int
        Input dimension
    attention_feats: int
        Attention dimension
    num_heads: int
        Number of heads
    feat_drop: float, default 0.
        Dropout rate for feature
    attn_drop: float, default 0.
        Dropout rate for attention
    negative_slope: float, default 0.2
        Negative slope for LeakyReLU
    activation: callable, default None
        Activation function
    allow_zero_in_degree: bool, default False
        Whether to allow zero in degree
    bias: bool, default True
        Whether to include bias in linear transformations
    """

    def __init__(self,
                 aggregator,
                 n_nodes,
                 in_feats,
                 attention_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=True):
        super(ReviewConv, self).__init__()

        # Set parameters
        self.aggregator = aggregator
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = dgl.utils.expand_as_pair(in_feats)
        self._out_feats = attention_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self.fc_src = nn.Linear(
            self._in_src_feats, attention_feats * num_heads, bias=bias)

        # Initialize embeddings and layers used for other methods
        if self.aggregator == 'narre':
            # Per-node "quality" embedding that is added to the review representation on each edge.
            self.node_quality = nn.Embedding(n_nodes, self._in_dst_feats)
            self.fc_qual = nn.Linear(self._in_dst_feats, attention_feats * num_heads, bias=bias)
        elif self.aggregator == 'gatv2':
            pass
        else:
            raise NotImplementedError(f'Not implemented any aggregator named {self.aggregator}.')

        # Attention vector. It is allocated uninitialised, so it must be given values here;
        # otherwise the attention scores depend on whatever happened to be in memory.
        self.attn = nn.Parameter(torch.empty(1, num_heads, attention_feats))
        nn.init.xavier_normal_(self.attn, gain=nn.init.calculate_gain('relu'))

        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.activation = activation
        self.bias = bias

    def rel_attention(self, lhs_field, rhs_field, out, w, b, source=True):
        """
        Build an edge UDF applying a relation-specific linear map to a feature.

        The returned function reads the relation index from ``edges.data[rhs_field]``, takes the
        feature ``lhs_field`` from the source node (``source=True``) or from the edge itself, and
        stores ``gather_mm(feature, w, idx) + b[idx]`` under ``out``. Not called by ``forward``.
        """
        def func(edges):
            idx = edges.data[rhs_field]
            data = edges.src[lhs_field] if source else edges.data[lhs_field]
            return {out: dgl.ops.gather_mm(data, w, idx_b=idx) + b[idx]}
        return func

    def forward(self, graph, feat, get_attention=False):
        """
        Description
        -----------
        Compute attention over incoming review edges and aggregate reviews into destination nodes.

        Parameters
        ----------
        graph : DGLGraph
            The graph. For the 'narre' aggregator it must carry the node id of each edge in
            ``edata['nid']``.
        feat : torch.Tensor
            The source node features; the last dimension must equal the configured input dimension.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        torch.Tensor
            The aggregated representation of the destination nodes. For 'gatv2' the transformed
            features (one slice per head) are aggregated; for 'narre' the (dropped-out) input
            features are aggregated instead.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            # Check if any 0-in-degree nodes
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise dgl.DGLError('There are 0-in-degree nodes in the graph, '
                                       'output for those nodes will be invalid. '
                                       'This is harmful for some applications, '
                                       'causing silent performance regression. '
                                       'Adding self-loop on the input graph by '
                                       'calling `g = dgl.add_self_loop(g)` will resolve '
                                       'the issue. Setting ``allow_zero_in_degree`` '
                                       'to be `True` when constructing this module will '
                                       'suppress the check and let the code run.')

            # Drop features
            h_src = self.feat_drop(feat)

            # Transform src node features to attention space
            feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
            graph.srcdata.update({'el': feat_src})  # (num_src_nodes, num_heads, out_dim)

            # Move messages to edges
            if self.aggregator == 'narre':
                # Get quality representation for user/item
                h_qual = self.feat_drop(self.node_quality(graph.edata['nid']))

                # Transform to attention space and add to edge data
                feat_qual = self.fc_qual(h_qual).view(-1, self._num_heads, self._out_feats)
                graph.edata.update({'qual': feat_qual})

                # Add node and quality representation on edges.
                graph.apply_edges(fn.u_add_e('el', 'qual', 'e'))
            else:
                graph.apply_edges(fn.copy_u('el', 'e'))

            # Get attention representation
            e = self.leaky_relu(graph.edata.pop('e'))  # (num_edge, num_heads, out_dim)

            # Compute attention score
            e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)  # (num_edge, num_heads, 1)

            # Normalize attention using softmax on edges
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))  # (num_edge, num_heads, 1)

            # If using narre set node representation to original input instead of attention representation
            if self.aggregator == 'narre':
                graph.srcdata.update({'el': h_src})

            # Aggregate reviews to nodes.
            graph.update_all(fn.u_mul_e('el', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']

            # If using activation, apply
            if self.activation:
                rst = self.activation(rst)

            # In inference, we may want to get the attention. If so return both review representation and attention.
            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst
+ Parameters + ---------- + g: dgl.DGLGraph + Heterogeneous graph with user and item nodes. + n_nodes: int + Number of nodes + aggregator: str + Aggregator to use. Can be 'gatv2' and 'narre'. + predictor: str + Predictor to use. Can be 'narre' and 'dot'. + node_dim: int + Dimension of node embeddings + incidence_dict: + Incidence matrix for each relation relation type. I.e., positive and negative AO pairs. + num_heads: int + Number of heads to use for review aggregation. + layer_dropout: list + Dropout rate for hypergraph and for review attention layer. + attention_dropout: float + Dropout rate for attention. + preference_module: str + Preference module to use. Can be 'lightgcn' and 'mf'. + use_cuda: bool + Whether we are using cuda. + combiner: str + Combiner to use. Can be 'add', 'mul', 'bi-interaction', 'concat', 'review-only', 'self', 'self-only'. + aos_predictor: str + AOS predictor to use. Can be 'non-linear' and 'transr'. + non_linear: bool + Whether to use non-linear activation function. + embedding_type: str + Type of embedding to use. Can be 'learned' and 'ao_embeddings'. + kwargs: dict + Additional arguments, such the learned embeddings. 
+ """ + + def __init__(self, g, n_nodes, aggregator, predictor, node_dim, + incidence_dict, + num_heads, layer_dropout, attention_dropout, preference_module='lightgcn', use_cuda=True, + combiner='add', aos_predictor='non-linear', non_linear=False, embedding_type='learned', + **kwargs): + super().__init__() + from .lightgcn import Model as lightgcn + self.aggregator = aggregator + self.embedding_type = embedding_type + self.predictor = predictor + self.preference_module = preference_module + self.node_dim = node_dim + self.num_heads = num_heads + self.combiner = combiner + + if embedding_type == 'learned': + self.node_embedding = nn.Embedding(n_nodes, node_dim) + elif embedding_type == 'ao_embeddings': + self.node_embedding = nn.Embedding(n_nodes, node_dim) + self.learned_embeddings = kwargs['ao_embeddings'] + + # Layer to convert learned embeddings to node embeddings + dims = [self.learned_embeddings.size(-1), 256, 128, self.node_dim] + self.node_embedding_mlp = nn.Sequential( + *[nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.Tanh()) for i in range(len(dims)-1)] + ) + else: + raise ValueError(f'Invalid embedding type {embedding_type}') + + # Define review aggregation layer + n_layers = 3 + self.review_conv = HypergraphLayer(incidence_dict, node_dim, non_linear=non_linear, num_layers=n_layers, + dropout=layer_dropout[0]) + # Define review attention layer + self.review_agg = ReviewConv(aggregator, n_nodes, node_dim, node_dim, num_heads, + feat_drop=layer_dropout[1], attn_drop=attention_dropout) + # Define dropout + self.node_dropout = nn.Dropout(layer_dropout[0]) + + # Define preference module + self.lightgcn = lightgcn(g, node_dim, 3, 0) + + # Define out layers + self.W_s = nn.Linear(node_dim, node_dim, bias=False) + if aggregator == 'narre': + self.w_0 = nn.Linear(node_dim, node_dim) + + # Define combiner + final_dim = node_dim + assert combiner in ['add', 'mul', 'bi-interaction', 'concat', 'review-only', 'self', 'self-only'] + if combiner in ['concat', 'self']: 
+ final_dim *= 2 # Increases out embeddings + elif combiner == 'bi-interaction': + # Add and multiply MLPs + self.add_mlp = nn.Sequential( + nn.Linear(node_dim, node_dim), + nn.Tanh() + ) + self.mul_mlp = nn.Sequential( + nn.Linear(node_dim, node_dim), + nn.Tanh() + ) + + # Define predictor + if self.predictor == 'narre': + self.edge_predictor = dgl.nn.EdgePredictor('ele', final_dim, 1, bias=True) + self.bias = nn.Parameter(torch.zeros((n_nodes, 1))) + + # Define aos predictor + self.aos_predictor = AOSPredictionLayer(aos_predictor, node_dim, final_dim, [node_dim, 64, 32], 2, + loss='transr') + + # Define loss functions + self.rating_loss_fn = nn.MSELoss(reduction='mean') + self.bpr_loss_fn = nn.Softplus() + + # Define embeddings used on inference. + self.review_embs = None + self.inf_emb = None + self.lemb = None + self.first = True + self.review_attention = None + self.ui_emb = None + self.aos_emb = None + + # Initialize parameters + self.reset_parameters() + + def reset_parameters(self): + for name, parameter in self.named_parameters(): + if name.endswith('bias'): + nn.init.constant_(parameter, 0) + else: + nn.init.xavier_normal_(parameter) + + def get_initial_embedings(self, nodes=None): + """ + Get initial embeddings for nodes. + Parameters + ---------- + nodes: torch.Tensor, optional + Nodes to get embeddings for, if none return all. + + Returns + ------- + torch.Tensor + Embeddings for nodes. + """ + + if self.embedding_type == 'learned': + # If all nodes are learned, only use node embeddings + if nodes is not None: + return self.node_embedding(nodes) + else: + return self.node_embedding.weight + elif self.embedding_type == 'ao_embeddings': + # If AO embeddings are prelearned, use them and filter rest. 
+ + # If nodes are given select those, else use all embeddings + if nodes is not None: + filter_val = self.node_embedding.weight.size(0) + mask = nodes >= filter_val + emb = torch.empty((*nodes.size(), self.node_dim), device=nodes.device) + emb[~mask] = self.node_embedding(nodes[~mask]) # Get node embeddings from learned embeddings for UI + + # If any nodes are prelearned, get features from these. + if torch.any(mask): + emb[mask] = self.node_embedding_mlp(self.learned_embeddings[nodes[mask]-filter_val]) + return emb + else: + # Return all embeddings + return torch.cat([self.node_embedding.weight, + self.node_embedding_mlp(self.learned_embeddings)], dim=0) + else: + raise ValueError(f'Does not support {self.embedding_type}') + + def review_representation(self, x, mask=None): + """ + Compute review representation. + Parameters + ---------- + x: torch.Tensor + Input features + mask: torch.Tensor, optional + Mask to use for training. + + Returns + ------- + torch.Tensor + Review representation + """ + + return self.review_conv(x, mask) + + def review_aggregation(self, g, x, attention=False): + """ + Aggregate reviews. + Parameters + ---------- + g: dgl.DGLGraph + Graph used for aggregation + x: torch.Tensor + Input features + attention: bool, default False + Whether to return attention. + + Returns + ------- + torch.Tensor, optional attention + user or item representation based on reviews. If attention is True, return attention as well. + """ + + # Aggregate reviews + x = self.review_agg(g, x, attention) + + # Expand if using attention + if attention: + x, a = x + + # Sum over heads + x = x.sum(1) + + # Return attention if needed + if attention: + return x, a + else: + return x + + def forward(self, blocks, x, input_nodes): + """ + Forward pass for HypAR model. + Parameters + ---------- + blocks: list + List of blocks for preference module, review module and mask. + x: torch.Tensor + Input features + input_nodes: torch.Tensor + Nodes to use for input. 
+ + Returns + ------- + torch.Tensor, torch.Tensor + Node representations used for AOS and node representations for prediction. + """ + + # Compute preference embeddings + blocks, lgcn_blocks, mask = blocks + + # First L-1 blocks for LightGCN are used for convolutions. Last maps review and preference blocks. + if self.preference_module == 'lightgcn': + u, i, _ = self.lightgcn(lgcn_blocks[:-1]) + e = {'user': u, 'item': i} + elif self.preference_module == 'mf': + # Get user/item representation without any graph convolutions. + # Use srcdata from last block to get user/item embeddings. + e = {ntype: self.lightgcn.features[ntype](nids) for ntype, nids in + lgcn_blocks[-1].srcdata[dgl.NID].items() if ntype != 'node'} + else: + raise NotImplementedError(f'{self.preference_module} is not supported') + + # Move all nodes into same sorting (non-typed) as reviews does not divide user/item by type. + g = lgcn_blocks[-1] + with g.local_scope(): + g.srcdata['h'] = e + funcs = {etype: (fn.copy_u('h', 'm'), fn.sum('m', 'h')) for etype in g.etypes} + g.multi_update_all(funcs, 'sum') + e = g.dstdata['h']['node'] + + # Compute review embeddings + x = self.node_dropout(x) + node_representation, r_ui = self.review_representation(x, mask) + + # Aggregate reviews + b, = blocks + r_ui = r_ui[b.srcdata[dgl.NID]] + r_n = self.review_aggregation(b, r_ui) # Node representation from reviews + + # Dropout + r_n, e = self.node_dropout(r_n), self.node_dropout(e) + + # Combine preference and explainability + if self.combiner == 'concat': + e_star = torch.cat([r_n, e], dim=-1) + elif self.combiner == 'add': + e_star = r_n + e + elif self.combiner == 'bi-interaction': + a = self.add_mlp(r_n + e) + m = self.mul_mlp(r_n * e) + e_star = a + m + elif self.combiner == 'mul': + e_star = r_n * e + elif self.combiner == 'review-only': + e_star = r_n + elif self.combiner == 'self': + e_star = torch.cat([r_n, node_representation[b.dstdata[dgl.NID]]], dim=-1) + elif self.combiner == 'self-only': + e_star = 
node_representation[b.dstdata[dgl.NID]] + + return node_representation, e_star + + def _graph_predict_dot(self, g: dgl.DGLGraph, x): + # Dot product prediction + with g.local_scope(): + g.ndata['h'] = x + g.apply_edges(fn.u_dot_v('h', 'h', 'm')) + + return g.edata['m'].reshape(-1, 1) + + def _graph_predict_narre(self, g: dgl.DGLGraph, x): + # Narre prediction methodology + with g.local_scope(): + g.ndata['b'] = self.bias[g.ndata[dgl.NID]] + g.apply_edges(fn.u_add_v('b', 'b', 'b')) # user/item bias + + u, v = g.edges() + x = self.edge_predictor(x[u], x[v]) + out = x + g.edata['b'] + + return out + + def graph_predict(self, g: dgl.DGLGraph, x): + # Predict using graph + if self.predictor == 'dot': + return self._graph_predict_dot(g, x) + elif self.predictor == 'narre': + return self._graph_predict_narre(g, x) + else: + raise ValueError(f'Predictor not implemented for "{self.predictor}".') + + def aos_graph_predict(self, g: dgl.DGLGraph, node_rep, e_star): + """ + AOS graph prediction. + Parameters + ---------- + g: dgl.DGLGraph + Graph to use for prediction. Should have edata['pos'] and edata['neg'] representing positive and negative + aspect and opinion pairs. + node_rep: torch.Tensor + Node representation for AO representation. + e_star: torch.Tensor + Node representation for user/item. + + Returns + ------- + torch.Tensor + Loss of prediction. + """ + with g.local_scope(): + # Get user/item embeddings + u, v = g.edges() + u_emb, i_emb = e_star[u], e_star[v] + + # Get positive a/o embeddings. + a, o, s = g.edata['pos'].T + a_emb, o_emb = node_rep[a], node_rep[o] + + # Predict using AOS predictor + preds_i = self.aos_predictor(u_emb, i_emb, a_emb, o_emb, s) + + # Get negative a/o embeddings + a, o, s = g.edata['neg'].permute(2, 0, 1) + a_emb, o_emb = node_rep[a], node_rep[o] + + # Predict using AOS predictor + preds_j = self.aos_predictor(u_emb, i_emb, a_emb, o_emb, s) + + # Calculate loss using bpr or transr loss (order differs). 
+ if self.aos_predictor.loss == 'bpr': + return self.bpr_loss_fn(- (preds_i - preds_j)), preds_i > preds_j + else: + return self.bpr_loss_fn(- (preds_j - preds_i)), preds_i < preds_j + + def _predict_dot(self, u_emb, i_emb): + # Predict using dot + return (u_emb * i_emb).sum(-1) + + def _predict_narre(self, user, item, u_emb, i_emb): + # Predict using narre + h = self.edge_predictor(u_emb, i_emb) + h += (self.bias[user] + self.bias[item]) + + return h.reshape(-1, 1) + + def _combine(self, user, item): + # Use embeddings computed using self.inference + u_emb, i_emb = self.inf_emb[user], self.inf_emb[item] # review user/item embedding + lu_emb, li_emb = self.lemb[user], self.lemb[item] # preference user/item embedding, e.g., lightgcn + + # Depending on combiner, combine embeddings + # if using self or self-only, then lu/li_emb are based on explainability module only, not preference module. + if self.combiner in ['concat', 'self']: + u_emb = torch.cat([u_emb, lu_emb], dim=-1) + i_emb = torch.cat([i_emb, li_emb], dim=-1) + elif self.combiner == 'add': + u_emb += lu_emb + i_emb += li_emb + elif self.combiner == 'bi-interaction': + a = self.add_mlp(u_emb + lu_emb) + m = self.mul_mlp(u_emb * lu_emb) + u_emb = a + m + a = self.add_mlp(i_emb + li_emb) + m = self.mul_mlp(i_emb * li_emb) + i_emb = a + m + elif self.combiner == 'mul': + u_emb *= lu_emb + i_emb *= li_emb + elif self.combiner == 'review-only': + pass + elif self.combiner == 'self-only': + u_emb, i_emb = lu_emb, li_emb + + # Return user item embeddings + return u_emb, i_emb + + def predict(self, user, item): + """ + Predict using model. + Parameters + ---------- + user: torch.Tensor + User ids + item: torch.Tensor + Item ids + + Returns + ------- + torch.Tensor + Predicted ranking/rating. 
+ """ + u_emb, i_emb = self._combine(user, item) + + if self.predictor == 'dot': + pred = self._predict_dot(u_emb, i_emb) + elif self.predictor == 'narre': + pred = self._predict_narre(user, item, u_emb, i_emb) + else: + raise ValueError(f'Predictor not implemented for "{self.predictor}".') + + return pred + + def rating_loss(self, preds, target): + return self.rating_loss_fn(preds, target.unsqueeze(-1)) + + def ranking_loss(self, preds_i, preds_j, loss_fn='bpr'): + if loss_fn == 'bpr': + loss = self.bpr_loss_fn(- (preds_i - preds_j)) + else: + raise NotImplementedError + + return loss.mean() + + def inference(self, node_review_graph, ui_graph, device, batch_size): + """ + Inference for HypAR model. + Parameters + ---------- + node_review_graph: dgl.DGLGraph + Graph mapping reviews to nodes. + ui_graph: dgl.DGLGraph + Graph with user/item mappings + device: str + Device to use for inference. + batch_size: int + Batch size to use for inference. + + Returns + ------- + None + """ + + # Review inference. nx is the node representation. 
+ x = self.get_initial_embedings() + nx, self.review_embs = self.review_representation(x) + + # Node inference setup + indices = {'node': node_review_graph.nodes('node')} + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + dataloader = dgl.dataloading.DataLoader(node_review_graph, indices, sampler, batch_size=batch_size, shuffle=False, + drop_last=False, device=device) + + # Initialize embeddings + self.inf_emb = torch.zeros((torch.max(indices['node'])+1, self.node_dim)).to(device) + self.review_attention = torch.zeros((node_review_graph.num_edges(), self.review_agg._num_heads, 1)).to(device) + + # Aggregate reviews using attention + for input_nodes, output_nodes, blocks in dataloader: + x, a = self.review_aggregation(blocks[0]['part_of'], self.review_embs[input_nodes['review']], True) + self.inf_emb[output_nodes['node']] = x + self.review_attention[blocks[0]['part_of'].edata[dgl.EID]] = a + + # Node preference embedding + if self.preference_module == 'lightgcn': + u, i, _ = self.lightgcn(ui_graph) + x = {'user': u, 'item': i} + else: + x = {nt: e.weight for nt, e in self.lightgcn.features.items()} + + # Combine/stack useritem embeddings + if self.combiner.startswith('self'): + x = nx + else: + x = torch.cat([x['item'], x['user']], dim=0) + + # Set embeddings for prediction + self.lemb = x + + + + diff --git a/cornac/models/hypar/lightgcn.py b/cornac/models/hypar/lightgcn.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe37a460108fb7fdb48affef04aeb34126aa1c8 --- /dev/null +++ b/cornac/models/hypar/lightgcn.py @@ -0,0 +1,135 @@ +from typing import Union, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl +import dgl.function as fn + +USER_KEY = "user" +ITEM_KEY = "item" + + +def construct_graph(data_set, total_users, total_items): + """ + Generates graph given a cornac data set + + Parameters + ---------- + data_set : cornac.data.dataset.Dataset + The data set as provided by cornac + """ + 
user_indices, item_indices, _ = data_set.uir_tuple + + data_dict = { + (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices), + (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices), + } + num_dict = {USER_KEY: total_users, ITEM_KEY: total_items} + + g = dgl.heterograph(data_dict, num_nodes_dict=num_dict) + norm_dict = {} + for srctype, etype, dsttype in g.canonical_etypes: + src, dst = g.edges(etype=(srctype, etype, dsttype)) + dst_degree = g.in_degrees( + dst, etype=(srctype, etype, dsttype) + ).float() # obtain degrees + src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() + norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm + g.edata['norm'] = {etype: norm} + + return g + + +class GCNLayer(nn.Module): + def __init__(self): + super(GCNLayer, self).__init__() + + def forward(self, g, feat_dict): + funcs = {} # message and reduce functions dict + # for each type of edges, compute messages and reduce them all + g.ndata["h"] = feat_dict + for srctype, etype, dsttype in g.canonical_etypes: + funcs[(srctype, etype, dsttype)] = ( + fn.u_mul_e("h", "norm", "m"), + fn.sum("m", "h_n"), + ) # define message and reduce functions + + g.multi_update_all( + funcs, "sum" + ) # update all, reduce by first type-wisely then across different types + return g.dstdata["h_n"] + + +class Model(nn.Module): + def __init__(self, g, in_size, num_layers, lambda_reg, device=None): + super(Model, self).__init__() + self.norm_dict = dict() + self.lambda_reg = lambda_reg + self.device = device + + self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)]) + + self.initializer = nn.init.xavier_uniform_ + + # embeddings for different types of nodes + self.feature_dict = nn.ParameterDict( + { + ntype: nn.Parameter( + self.initializer(torch.empty(g.num_nodes(ntype), in_size)) + ) + for ntype in g.ntypes + } + ) + + def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None): + 
+ if isinstance(in_g, list): + h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes} + user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)] + item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)] + iterator = enumerate(zip(in_g, self.layers)) + else: + h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes} + # obtain features of each layer and concatenate them all + user_embeds = h_dict[USER_KEY] + item_embeds = h_dict[ITEM_KEY] + iterator = enumerate(zip([in_g] * len(self.layers), self.layers)) + + for k, (g, layer) in iterator: + h_dict = layer(g, h_dict) + ue = h_dict[USER_KEY] + ie = h_dict[ITEM_KEY] + + if isinstance(in_g, list): + ue = ue[in_g[-1].dstnodes(USER_KEY)] + ie = ie[in_g[-1].dstnodes(ITEM_KEY)] + + user_embeds = user_embeds + ue + item_embeds = item_embeds + ie + + user_embeds = user_embeds / (len(self.layers) + 1) + item_embeds = item_embeds / (len(self.layers) + 1) + + u_g_embeddings = user_embeds if users is None else user_embeds[users, :] + pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :] + neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :] + + return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings + + def loss_fn(self, users, pos_items, neg_items): + pos_scores = (users * pos_items).sum(1) + neg_scores = (users * neg_items).sum(1) + + bpr_loss = F.softplus(neg_scores - pos_scores).mean() + reg_loss = ( + (1 / 2) + * ( + torch.norm(users) ** 2 + + torch.norm(pos_items) ** 2 + + torch.norm(neg_items) ** 2 + ) + / len(users) + ) + + return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss diff --git a/cornac/models/hypar/recom_hypar.py b/cornac/models/hypar/recom_hypar.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ffa03012e83df9b02bf5e61d3fb0f2bcbf34ba --- /dev/null +++ b/cornac/models/hypar/recom_hypar.py @@ -0,0 +1,843 @@ +import collections +import os +import 
pickle +from collections import defaultdict +from contextlib import nullcontext +from copy import deepcopy + +from ..recommender import Recommender +from ...data import Dataset + + +class HypAR(Recommender): + """ + HypAR: Hypergraph with Attention on Review. This model is from the paper "Hypergraph with Attention on Reviews + for explainable recommendation", by Theis E. Jendal, Trung-Hoang Le, Hady W. Lauw, Matteo Lissandrini, + Peter Dolog, and Katja Hose. + ECIR 2024: https://doi.org/10.1007/978-3-031-56027-9_14 + + Parameters + ---------- + name: str, default: 'HypAR' + Name of the model. + use_cuda: bool, default: False + Whether to use cuda. + stemming: bool, default: True + Whether to use stemming. + batch_size: int, default: 128 + Batch size. + num_workers: int, default: 0 + Number of workers for dataloader. + num_epochs: int, default: 10 + Number of epochs. + early_stopping: int, default: 10 + Early stopping. + eval_interval: int, default: 1 + Evaluation interval, i.e., how often to evaluate on the validation set. + learning_rate: float, default: 0.1 + Learning rate. + weight_decay: float, default: 0 + Weight decay. + node_dim: int, default: 64 + Dimension of learned and hidden layers. + num_heads: int, default: 3 + Number of attention heads. + fanout: int, default: 5 + Fanout for sampling. + non_linear: bool, default: True + Whether to use non-linear activation function. + model_selection: str, default: 'best' + Model selection method, i.e., whether to use the best model or the last model. + objective: str, default: 'ranking' + Objective, i.e., whether to use ranking or rating. + review_aggregator: str, default: 'narre' + Review aggregator, i.e., how to aggregate reviews. + predictor: str, default: 'narre' + Predictor, i.e., how to predict ratings. + preference_module: str, default: 'lightgcn' + Preference module, i.e., how to model preferences. + combiner: str, default: 'add' + Combiner, i.e., how to combine embeddings. 
+ graph_type: str, default: 'aos' + Graph type, i.e., which nodes to include in hypergraph. Aspects, opinions and sentiment. + num_neg_samples: int, default: 50 + Number of negative samples to use for ranking. + layer_dropout: float, default: None + Dropout for node and review embeddings. + attention_dropout: float, default: .2 + Dropout for attention. + user_based: bool, default: True + Whether to use user-based or item-based. + verbose: bool, default: True + Whether to print information. + index: int, default: 0 + Index for saving results, i.e., if hyparparameter tuning. + out_path: str, default: None + Path to save graphs, embeddings and similar. + learn_explainability: bool, default: False + Whether to learn explainability. + learn_method: str, default: 'transr' + Learning method, i.e., which method to use explainability learning. + learn_weight: float, default: 1. + Weight for explainability learning loss. + embedding_type: str, default: 'ao_embeddings' + Type of embeddings to use, i.e., whether to use prelearned embeddings or not. + debug: bool, default: False + Whether to use debug mode as errors might be thrown by dataloaders when debugging. + """ + def __init__(self, + name='HypAR', + use_cuda=False, + stemming=True, + batch_size=128, + num_workers=0, + num_epochs=10, + early_stopping=10, + eval_interval=1, + learning_rate=0.1, + weight_decay=0, + node_dim=64, + num_heads=3, + fanout=5, + non_linear=True, + model_selection='best', + objective='ranking', + review_aggregator='narre', + predictor='narre', + preference_module='lightgcn', + combiner='add', + graph_type='aos', + num_neg_samples=50, + layer_dropout=None, + attention_dropout=.2, + user_based=True, + verbose=True, + index=0, + out_path=None, + learn_explainability=False, + learn_method='transr', + learn_weight=1., + embedding_type='ao_embeddings', + debug=False, + ): + super().__init__(name) + # Default values + if layer_dropout is None: + layer_dropout = 0. 
# node embedding dropout, review embedding dropout + + # CUDA + self.use_cuda = use_cuda + self.device = 'cuda' if use_cuda else 'cpu' + + # Parameters + self.batch_size = batch_size + self.num_workers = num_workers + self.num_epochs = num_epochs + self.early_stopping = early_stopping + self.eval_interval = eval_interval + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.node_dim = node_dim + self.num_heads = num_heads + self.fanout = fanout + self.non_linear = non_linear + self.model_selection = model_selection + self.objective = objective + self.review_aggregator = review_aggregator + self.predictor = predictor + self.preference_module = preference_module + self.combiner = combiner + self.graph_type = graph_type + self.num_neg_samples = num_neg_samples + self.layer_dropout = layer_dropout + self.attention_dropout = attention_dropout + self.stemming = stemming + self.learn_explainability = learn_explainability + self.learn_method = learn_method + self.learn_weight = learn_weight + self.embedding_type = embedding_type + + # Method + self.node_review_graph = None + self.review_graphs = {} + self.train_graph = None + self.ui_graph = None + self.model = None + self.n_items = 0 + self.n_relations = 0 + self.ntype_ranges = None + self.node_filter = None + self.sid_aos = None + self.aos_tensor = None + + # Misc + self.user_based = user_based + self.verbose = verbose + self.debug = debug + self.index = index + self.out_path = out_path + + # assertions + assert objective == 'ranking' or objective == 'rating', f'This method only supports ranking or rating, ' \ + f'not {objective}.' + if early_stopping is not None: + assert early_stopping % eval_interval == 0, 'interval should be a divisor of early stopping value.' + + def _create_graphs(self, train_set: Dataset, graph_type='aos'): + """ + Create graphs required for training and returns all relevant data for future computations. 
+ Parameters + ---------- + train_set: Dataset + graph_type: str, which can contain a,o and s, where a is aspect, o is opinion and s is sentiment. E.g., if + a or o, then aspect and opinion are included, if s, then splitting on sentiment is included. + + Returns + ------- + num nodes, num node types, num items, train graph, hyper edges, node review graph, type ranges, sid to aos, and + aos triple list. + """ + import dgl + import torch + from tqdm import tqdm + from .dgl_utils import generate_mappings + + sentiment_modality = train_set.sentiment + n_users = len(train_set.uid_map) + n_items = len(train_set.iid_map) + + # Group and prune aspects and opinions + _, _, _, _, _, _, a2a, o2o = generate_mappings(train_set.sentiment, 'a', get_ao_mappings=True) + + # Get num and depending on graphtype, calculate tot num of embeddings + n_aspects = max(a2a.values()) + 1 if self.stemming else len(sentiment_modality.aspect_id_map) + n_opinions = max(o2o.values()) + 1 if self.stemming else len(sentiment_modality.opinion_id_map) + n_nodes = n_users + n_items + n_types = 4 + if 'a' in graph_type: + n_nodes += n_aspects + n_types += 1 + if 'o' in graph_type: + n_nodes += n_opinions + n_types += 1 + + # Map users to review ids. + user_item_review_map = {(uid + n_items, iid): rid for uid, irid in sentiment_modality.user_sentiment.items() + for iid, rid in irid.items()} + + # Initialize relevant lists + review_edges = [] + ratings = [] + if 's' in graph_type: + hyper_edges = {'p': [], 'n': []} + else: + hyper_edges = {'n': []} + sent_mapping = {-1: 'n', 1: 'p'} + + # Create review edges and ratings + sid_map = {sid: i for i, sid in enumerate(train_set.sentiment.sentiment)} + for uid, isid in tqdm(sentiment_modality.user_sentiment.items(), desc='Creating review graphs', + total=len(sentiment_modality.user_sentiment), disable=not self.verbose): + uid += n_items # Shift to user node id. 
+ + for iid, sid in isid.items(): + # Sid is used as edge id, i.e., each sid represent a single review. + first_sentiment = {'p': True, 'n': True} # special handling for initial sentiment. + review_edges.extend([[sid, uid], [sid, iid]]) # Add u/i to review aggregation + ratings.extend([train_set.matrix[uid - n_items, iid]] * 2) # Add to rating list + aos = sentiment_modality.sentiment[sid] # get aspects, opinions and sentiments for review. + for aid, oid, s in aos: + # Map sentiments if using else, use default (n). + if 's' in graph_type: + sent = sent_mapping[s] + else: + sent = 'n' + + # If stemming and pruning data, use simplified id (i.e., mapping) + if self.stemming: + aid = a2a[aid] + oid = o2o[oid] + + # Add to hyper edges, i.e., connect user, item, aspect and opinion to sentiment. + if first_sentiment[sent]: + hyper_edges[sent].extend([(iid, sid), (uid, sid)]) + first_sentiment[sent] = False + + # Shift aspect and opinion ids to node id. + aid += n_items + n_users + oid += n_items + n_users + + # If using aspect/opinion, add to hyper edges. + if 'a' in graph_type: + hyper_edges[sent].append((aid, sid)) + oid += n_aspects # Shift opinion id to correct node id if using both aspect and opinion. + if 'o' in graph_type: + hyper_edges[sent].append((oid, sid)) + + # Convert to tensor + for k, v in hyper_edges.items(): + hyper_edges[k] = torch.LongTensor(v).T + + # Create training graph, i.e. user to item graph. + edges = [(uid + n_items, iid, train_set.matrix[uid, iid]) for uid, iid in zip(*train_set.matrix.nonzero())] + t_edges = torch.LongTensor(edges).T + train_graph = dgl.graph((t_edges[0], t_edges[1])) + train_graph.edata['sid'] = torch.LongTensor([user_item_review_map[(u, i)] for (u, i, r) in edges]) + train_graph.edata['label'] = t_edges[2].to(torch.float) + + # Create user/item to review graph. 
def _flock_wrapper(self, func, fname, *args, rerun=False, **kwargs):
    """
    Wrapper for loading and saving data without accidental overrides and dual computation when running in parallel.
    If the file exists (and ``rerun`` is False) its pickled content is loaded and returned, otherwise ``func`` is
    run and its result is pickled to the file. Everything happens while holding a ``filelock.FileLock`` on
    ``fname + '.lock'`` inside ``self.out_path``.

    The result is first pickled to ``fname + '.tmp'`` and only moved onto the real file name with ``os.replace``
    once pickling succeeded. A failure while pickling therefore never leaves a truncated file behind that a
    later call would try (and fail) to load.

    Parameters
    ----------
    func: function
        Function to run.
    fname: str
        File name to save/load, relative to ``self.out_path``.
    args: list
        Arguments to function.
    rerun: bool, default: False
        If true, rerun function even if the file exists, and overwrite the file.
    kwargs: dict
        Keyword arguments to function.

    Returns
    -------
    Data from function, or the previously pickled data if the file was loaded.
    """
    from filelock import FileLock

    fpath = os.path.join(self.out_path, fname)
    lock_fpath = os.path.join(self.out_path, fname + '.lock')
    tmp_fpath = fpath + '.tmp'

    with FileLock(lock_fpath):
        if not rerun and os.path.exists(fpath):
            # NOTE(review): pickle.load can execute arbitrary code; out_path must only contain trusted files.
            with open(fpath, 'rb') as f:
                return pickle.load(f)

        data = func(*args, **kwargs)
        try:
            with open(tmp_fpath, 'wb') as f:
                pickle.dump(data, f)
            # Same directory, so the final file appears either complete or not at all.
            os.replace(tmp_fpath, fpath)
        finally:
            # Only still present if pickling or the move failed.
            if os.path.exists(tmp_fpath):
                os.remove(tmp_fpath)

    return data


def _graph_wrapper(self, train_set, graph_type, *args):
    """
    Wrapper for creating graphs and converting to correct format.
    Assigns values to self: n_items, train_graph, review_graphs, node_review_graph and ntype_ranges.
    Defines self.node_filter based on the type ranges.

    Parameters
    ----------
    train_set: Dataset
        Dataset to use for graph construction
    graph_type: str
        Which graph to create. Can contain a, o and s, where a is aspect, o is opinion and s is sentiment.
    args: list
        Additional arguments to graph creation function.

    Returns
    -------
    Num nodes, num types, sid to aos mapping, list of aos triples.
    """
    import dgl.sparse as dglsp
    import torch

    # Load graph data (created by self._create_graphs on first use, read from the pickle afterwards).
    fname = f'graph_{graph_type}_data.pickle'
    data = self._flock_wrapper(self._create_graphs, fname, train_set, graph_type, *args, rerun=False)

    # Expand data and assign to self
    n_nodes, n_types, self.n_items, self.train_graph, self.review_graphs, self.node_review_graph, \
        self.ntype_ranges, sid_aos, aos_list = data

    # Convert data to sparse matrices and assign to self.
    # Review graphs is dict with positive/negative sentiment (possibly).
    # All matrices share one shape: the largest index per row of the concatenated edge tensors, plus one.
    shape = torch.cat(list(self.review_graphs.values()), dim=-1).max(-1)[0] + 1
    for k, edges in self.review_graphs.items():
        H = dglsp.spmatrix(
            torch.unique(edges, dim=1), shape=shape.tolist()
        ).coalesce()
        # Duplicate edges were removed above, so coalescing must not have summed any entries.
        assert (H.val == 1).all()
        self.review_graphs[k] = H.to(self.device)

    # Boolean mask selecting the node ids that fall inside the [start, end) range of node type t.
    self.node_filter = lambda t, nids: (nids >= self.ntype_ranges[t][0]) * (nids < self.ntype_ranges[t][1])
    return n_nodes, n_types, sid_aos, aos_list
+ class CallbackProgressBar: + def __init__(self, verbose): + self.verbose = verbose + self.progress = None + + def on_train_begin(self, method): + if self.progress is None: + self.progress = tqdm(desc='Training Word2Vec', total=method.epochs, disable=not self.verbose) + + def on_train_end(self, method): + pass + + def on_epoch_begin(self, method): + pass + + def on_epoch_end(self, method): + self.progress.update(1) + + # Split words on space and get all unique words + wc = [s.split(' ') for s in corpus] + all_words = set(s for se in wc for s in se) + + # Assert all aspects and opinions in dataset are in corpus. If not, print missing words. + # New datasets may require more preprocessing. + assert all([a in all_words for a in a_old_new_map.values()]), [a for a in a_old_new_map.values() if + a not in all_words] + assert all([o in all_words for o in o_old_new_map.values()]), [o for o in o_old_new_map.values() if + o not in all_words] + + # Train word2vec model using callbacks for progressbar. + l = CallbackProgressBar(self.verbose) + embedding_dim = 100 + w2v_model = Word2Vec(wc, vector_size=embedding_dim, min_count=1, window=5, callbacks=[l], epochs=100) + + # Keyvector model + kv = w2v_model.wv + + # Initialize embeddings + a_embeddings = np.zeros((len(set(a2a.values())), embedding_dim)) + o_embeddings = np.zeros((len(set(o2o.values())), embedding_dim)) + + # Define function for assigning embeddings to correct aspect. + def get_info(old_new_pairs, mapping, embedding): + for old, new in old_new_pairs: + nid = mapping(old) + vector = np.array(kv.get_vector(new)) + embedding[nid] = vector + + return embedding + + # Assign embeddings to correct aspect and opinion. 
def _normalize_embedding(self, embedding):
    """
    Standardize an embedding matrix with scikit-learn's StandardScaler.

    Parameters
    ----------
    embedding: np.array
        Embedding to normalize. It is used both to fit the scaler and as the data that is transformed.

    Returns
    -------
    Normalized embedding and the fitted scaler.
    """
    from sklearn.preprocessing import StandardScaler

    fitted_scaler = StandardScaler().fit(embedding)  # fit() returns the scaler itself
    normalized = fitted_scaler.transform(embedding)
    return normalized, fitted_scaler


def _learn_initial_ao_embeddings(self, train_set):
    """
    Learn initial aspect and opinion embeddings.

    The raw embeddings come from self._ao_embeddings and are then scaled with self._normalize_embedding.
    Each of the three steps goes through self._flock_wrapper, i.e. is cached in its own pickle file.

    Parameters
    ----------
    train_set: Dataset
        Dataset to use for learning embeddings.

    Returns
    -------
    Aspect and opinion embeddings as torch tensors.
    """
    import torch

    # Raw embeddings; the third element returned by _ao_embeddings is not needed here.
    raw_aspects, raw_opinions, _ = self._flock_wrapper(self._ao_embeddings, 'ao_embeddingsv2.pickle', train_set)

    # Scale aspects first, then opinions. The scaler that comes back with each result is discarded;
    # it would only be required if new data had to be scaled the same way.
    scaled = []
    for cache_name, raw in (('aspect_embeddingsv2.pickle', raw_aspects),
                            ('opinion_embeddingsv2.pickle', raw_opinions)):
        normalized, _ = self._flock_wrapper(self._normalize_embedding, cache_name, raw)
        scaled.append(torch.tensor(normalized))

    return scaled[0], scaled[1]
+ n_nodes, self.n_relations, self.sid_aos, self.aos_list = self._graph_wrapper(train_set, + self.graph_type) + + # If using learned ao embeddings, learn and assign to kwargs + kwargs = {} + if self.embedding_type == 'ao_embeddings': + a_embs, o_embs = self._learn_initial_ao_embeddings(train_set) + + emb = [] + if 'a' in self.graph_type: + emb.append(a_embs) + if 'o' in self.graph_type: + emb.append(o_embs) + + if len(emb): + kwargs['ao_embeddings'] = torch.cat(emb).to(self.device).to(torch.float32) + n_nodes -= kwargs['ao_embeddings'].size(0) + else: + kwargs['ao_embeddings'] = torch.zeros((0, 0)) + + self.n_relations = 0 + + # Construct user-item graph used by lightgcn + self.ui_graph = construct_graph(train_set, self.num_users, self.num_items) + + # create model + from .hypar import Model + + self.model = Model(self.ui_graph, n_nodes, self.review_aggregator, + self.predictor, self.node_dim, self.review_graphs, self.num_heads, [self.layer_dropout] * 2, + self.attention_dropout, self.preference_module, self.use_cuda, combiner=self.combiner, + aos_predictor=self.learn_method, non_linear=self.non_linear, + embedding_type=self.embedding_type, + **kwargs) + + self.model.reset_parameters() + + if self.verbose: + print(f'Number of trainable parameters: {sum(p.numel() for p in self.model.parameters())}') + + if self.use_cuda: + self.model = self.model.cuda() + prefetch = ['label'] + else: + prefetch = [] + + # Train model + if self.trainable: + self._fit(prefetch, val_set) + + return self + + def _fit(self, prefetch, val_set=None): + import dgl + import torch + from torch import optim + from . 
import dgl_utils + import cornac + from tqdm import tqdm + + # Get graph and edges + g = self.train_graph + u, v = g.edges() + _, i, c = torch.unique(u, sorted=False, return_inverse=True, return_counts=True) + mask = c[i] > 1 + _, i, c = torch.unique(v, sorted=False, return_inverse=True, return_counts=True) + mask *= (c[i] > 1) + eids = g.edges(form='eid')[mask] + num_workers = self.num_workers + + if self.debug: + num_workers = 0 + + thread = False # Memory saving and does not increase speed. + + # Create sampler + sampler = dgl_utils.HypARBlockSampler(self.node_review_graph, self.review_graphs, self.review_aggregator, + self.sid_aos, self.aos_list, 5, + self.ui_graph, fanout=self.fanout) + + # If trained for ranking, define negative sampler only sampling items as negative samples. + if self.objective == 'ranking': + ic = collections.Counter(self.train_set.matrix.nonzero()[1]) + neg_sampler = dgl_utils.GlobalUniformItemSampler(self.num_neg_samples, self.train_set.num_items,) + else: + neg_sampler = None + + # Initialize sampler and dataloader + sampler = dgl_utils.HypAREdgeSampler(sampler, prefetch_labels=prefetch, negative_sampler=neg_sampler, + exclude='self') + dataloader = dgl.dataloading.DataLoader(g, eids, sampler, batch_size=self.batch_size, shuffle=True, + drop_last=True, device=self.device, + num_workers=num_workers, use_prefetch_thread=thread) + + # Initialize training params. 
+ optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + # Define metrics + if self.objective == 'ranking': + metrics = [cornac.metrics.NDCG(), cornac.metrics.AUC(), cornac.metrics.MAP(), cornac.metrics.MRR()] + else: + metrics = [cornac.metrics.MSE()] + + # Initialize variables for training + best_state = None + best_score = 0 if metrics[0].higher_better else float('inf') + best_epoch = 0 + epoch_length = len(dataloader) + all_nodes = torch.arange(next(iter(self.review_graphs.values())).shape[0]).to(self.device) + + # Train model + for e in range(self.num_epochs): + # Initialize for logging + tot_losses = defaultdict(int) + cur_losses = {} + self.model.train() + with tqdm(dataloader, disable=not self.verbose) as progress: + for i, batch in enumerate(progress, 1): + # Batch depends on objective + if self.objective == 'ranking': + input_nodes, edge_subgraph, neg_subgraph, blocks = batch + else: + input_nodes, edge_subgraph, blocks = batch + + # Get node representations and review representations + node_representation, e_star = self.model(blocks, self.model.get_initial_embedings(all_nodes), input_nodes) + + # Get preiction based on graph structure (edges represents ratings, thus predictions) + pred = self.model.graph_predict(edge_subgraph, e_star) + loss = 0 + if self.objective == 'ranking': + # Calculate predictions for negative subgraph and calculate ranking loss. + pred_j = self.model.graph_predict(neg_subgraph, e_star) + pred_j = pred_j.reshape(-1, self.num_neg_samples) + loss = self.model.ranking_loss(pred, pred_j) + + # Calculate accuracy + acc = (pred > pred_j).sum() / pred_j.shape.numel() + cur_losses['acc'] = acc.detach() + else: + # Calculate rating loss, if using prediction instead of ranking. + loss = self.model.rating_loss(pred, edge_subgraph.edata['label']) + + cur_losses['lloss'] = loss.clone().detach() # Learning loss + + # If using explainability, calculate loss and accuracy for explainability. 
+ if self.learn_explainability: + aos_loss, aos_acc = self.model.aos_graph_predict(edge_subgraph, node_representation, e_star) + aos_loss = aos_loss.mean() + cur_losses['aos_loss'] = aos_loss.detach() + cur_losses['aos_acc'] = (aos_acc.sum() / aos_acc.shape.numel()).detach() + loss += self.learn_weight * aos_loss # Add to loss with weight. + + cur_losses['totloss'] = loss.detach() + loss.backward() + + # Update batch losses + for k, v in cur_losses.items(): + tot_losses[k] += v.cpu() + + # Update model + optimizer.step() + optimizer.zero_grad() + + # Define printing + loss_str = ','.join([f'{k}:{v / i:.3f}' for k, v in tot_losses.items()]) + + # If not validating, else + if i != epoch_length or val_set is None: + progress.set_description(f'Epoch {e}, ' + loss_str) + elif (e + 1) % self.eval_interval == 0: + # If validating, validate and print results. + results = self._validate(val_set, metrics) + res_str = 'Val: ' + ','.join([f'{m.name}:{r:.4f}' for m, r in zip(metrics, results)]) + progress.set_description(f'Epoch {e}, ' + f'{loss_str}, ' + res_str) + + # If use best state and new best score, save state. + if self.model_selection == 'best' and \ + (results[0] > best_score if metrics[0].higher_better else results[0] < best_score): + best_state = deepcopy(self.model.state_dict()) + best_score = results[0] + best_epoch = e + + # Stop if no improvement. + if self.early_stopping is not None and (e - best_epoch) >= self.early_stopping: + break + + # Space efficiency + del self.node_filter + del g, eids + del dataloader + del sampler + + # Load best state if using best state. 
def _validate(self, val_set, metrics):
    """
    Run inference with the current model and evaluate it on the validation set.

    Parameters
    ----------
    val_set: Dataset
        Validation data passed to the evaluation helper.
    metrics: list
        Metrics passed to the evaluation helper.

    Returns
    -------
    The first element of the tuple returned by ranking_eval (if self.objective is 'ranking') or by
    rating_eval (otherwise).
    """
    from ...eval_methods.base_method import rating_eval, ranking_eval
    import torch

    # Do inference calculation without tracking gradients.
    self.model.eval()
    with torch.no_grad():
        self.model.inference(self.node_review_graph, self.ui_graph, self.device,
                             self.batch_size)

    # Evaluate model; the second element of the returned tuple is not used.
    if self.objective == 'ranking':
        (result, _) = ranking_eval(self, metrics, self.train_set, val_set)
    else:
        (result, _) = rating_eval(self, metrics, val_set, user_based=self.user_based)

    return result


def score(self, user_idx, item_idx=None):
    """
    Predict scores for a user.

    Parameters
    ----------
    user_idx: int
        User index. It is shifted by self.n_items before being handed to the model.
    item_idx: int, optional, default: None
        Item index. If None, all items in range(self.n_items) are scored.

    Returns
    -------
    If item_idx is None, a flattened numpy array with one prediction per item; otherwise the CPU tensor
    returned by self.model.predict for the given user and item.
    """
    import torch

    # Ensure model is in evaluation mode and not calculating gradient.
    self.model.eval()
    with torch.no_grad():
        # Shift user ids
        user_idx = torch.tensor(user_idx + self.n_items, dtype=torch.int64).to(self.device)

        if item_idx is not None:
            # Predict only the requested item.
            item_idx = torch.tensor(item_idx, dtype=torch.int64).to(self.device)
            return self.model.predict(user_idx, item_idx).cpu()

        # Predict all items.
        item_idx = torch.arange(self.n_items, dtype=torch.int64).to(self.device)
        return self.model.predict(user_idx, item_idx).reshape(-1).cpu().numpy()


def monitor_value(self, train_set, val_set=None):
    """Do nothing and return None."""
    pass


def save(self, save_dir=None, save_trainset=False):
    """
    Save the recommender through the parent class and additionally store the model's state dict.

    Parameters
    ----------
    save_dir: str, optional, default: None
        Directory to save into. If None, nothing is saved and None is returned.
    save_trainset: bool, default: False
        Passed on to the parent class' save.

    Returns
    -------
    The path returned by the parent class' save.
    """
    import torch

    if save_dir is None:
        return

    # Unset matrices to avoid pickling errors. Convert for review graphs due to same issues.
    # NOTE(review): neither is restored afterwards; load() shows how they are rebuilt.
    self.model.review_conv.unset_matrices()
    self.review_graphs = {k: (v.row, v.col, v.shape) for k, v in self.review_graphs.items()}

    # Save model
    path = super().save(save_dir, save_trainset)

    # Same file name as the pickled model, but with a '.pt' extension. basename/splitext handle the
    # platform's path separator and only touch the extension, not a 'pkl' elsewhere in the name.
    name = os.path.splitext(os.path.basename(path))[0] + '.pt'

    # Save state dict, only necessary if state should be used outside of class. Thus, not part of load.
    state = self.model.state_dict()
    torch.save(state, os.path.join(save_dir, str(self.index), name))

    return path


def load(self, model_path, trainable=False):
    """
    Load a saved recommender through the parent class and rebuild its sparse review graphs.

    Parameters
    ----------
    model_path: str
        Passed on to the parent class' load.
    trainable: bool, default: False
        Passed on to the parent class' load.

    Returns
    -------
    The loaded model, with review graphs converted back to sparse matrices on model.device.
    """
    import dgl.sparse as dglsp
    import torch

    # Load model
    model = super().load(model_path, trainable)

    # Convert review graphs from the (row, col, shape) tuples written by save back to sparse matrices.
    # The dict is updated in place; only values of existing keys are replaced.
    for k, (row, col, shape) in list(model.review_graphs.items()):
        model.review_graphs[k] = dglsp.spmatrix(torch.stack([row, col]), shape=shape).coalesce().to(model.device)

    # Set matrices.
    model.model.review_conv.set_matrices(model.review_graphs)

    return model
a/docs/source/api_ref/models.rst b/docs/source/api_ref/models.rst index e92c8881e23a37253968e18542a1a2c21b5a4810..2b2c599260e301a238b7b96e9f419b5d7e0f4691 100644 --- a/docs/source/api_ref/models.rst +++ b/docs/source/api_ref/models.rst @@ -39,6 +39,11 @@ Hybrid neural recommendation with joint deep representation learning of ratings .. automodule:: cornac.models.hrdr.recom_hrdr :members: +Hypergraphs with Attention on Reviews for Explainable Recommendation +-------------------------------------------------------------------------------------------------- +.. automodule:: cornac.models.hypar.recom_hypar + :members: + Simplifying and Powering Graph Convolution Network for Recommendation (LightGCN) -------------------------------------------------------------------------------- .. automodule:: cornac.models.lightgcn.recom_lightgcn