diff --git a/.gitignore b/.gitignore index 270230400f0e246644db8e05dcc6de126c2d96d6..01271755282894ddfe89bd70350a4da9587fbea7 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,10 @@ coverage.xml # Sphinx documentation docs/_build/ + + +# Environment +env +venv +.env +.venv diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py index 6fcdb7f5c42041f1e8c9e962da1173a6bb9bf359..c2f0dba92e429f13c72e43700d3a1e3bd9de7fe4 100644 --- a/cornac/models/lightgcn/lightgcn.py +++ b/cornac/models/lightgcn/lightgcn.py @@ -1,10 +1,11 @@ +from typing import Union, List + import torch import torch.nn as nn import torch.nn.functional as F import dgl import dgl.function as fn - USER_KEY = "user" ITEM_KEY = "item" @@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items): } num_dict = {USER_KEY: total_users, ITEM_KEY: total_items} - return dgl.heterograph(data_dict, num_nodes_dict=num_dict) + g = dgl.heterograph(data_dict, num_nodes_dict=num_dict) + norm_dict = {} + for srctype, etype, dsttype in g.canonical_etypes: + src, dst = g.edges(etype=(srctype, etype, dsttype)) + dst_degree = g.in_degrees( + dst, etype=(srctype, etype, dsttype) + ).float() # obtain degrees + src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() + norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm + g.edata['norm'] = {etype: norm} + + return g class GCNLayer(nn.Module): - def __init__(self, norm_dict): + def __init__(self): super(GCNLayer, self).__init__() - # norm - self.norm_dict = norm_dict - def forward(self, g, feat_dict): funcs = {} # message and reduce functions dict # for each type of edges, compute messages and reduce them all + g.ndata["h"] = feat_dict for srctype, etype, dsttype in g.canonical_etypes: - src, dst = g.edges(etype=(srctype, etype, dsttype)) - norm = self.norm_dict[(srctype, etype, dsttype)] - # TODO: CHECK HERE - messages = norm * feat_dict[srctype][src] # compute messages - g.edges[(srctype, 
etype, dsttype)].data[ - etype - ] = messages # store in edata funcs[(srctype, etype, dsttype)] = ( - fn.copy_e(etype, "m"), - fn.sum("m", "h"), + fn.u_mul_e("h", "norm", "m"), + fn.sum("m", "h_n"), ) # define message and reduce functions g.multi_update_all( funcs, "sum" ) # update all, reduce by first type-wisely then across different types - feature_dict = {} - for ntype in g.ntypes: - h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2) # l2 normalize - feature_dict[ntype] = h - return feature_dict + return g.dstdata["h_n"] class Model(nn.Module): @@ -69,16 +68,7 @@ class Model(nn.Module): self.lambda_reg = lambda_reg self.device = device - for srctype, etype, dsttype in g.canonical_etypes: - src, dst = g.edges(etype=(srctype, etype, dsttype)) - dst_degree = g.in_degrees( - dst, etype=(srctype, etype, dsttype) - ).float() # obtain degrees - src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() - norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm - self.norm_dict[(srctype, etype, dsttype)] = norm - - self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)]) + self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)]) self.initializer = nn.init.xavier_uniform_ @@ -92,16 +82,37 @@ class Model(nn.Module): } ) - def forward(self, g, users=None, pos_items=None, neg_items=None): - h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes} - # obtain features of each layer and concatenate them all - user_embeds = h_dict[USER_KEY] - item_embeds = h_dict[ITEM_KEY] + def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None): + + if isinstance(in_g, list): + h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes} + user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)] + item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)] + iterator = enumerate(zip(in_g, self.layers)) + else: + 
            h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes} +            # obtain features of each layer and concatenate them all +            user_embeds = h_dict[USER_KEY] +            item_embeds = h_dict[ITEM_KEY] +            iterator = enumerate(zip([in_g] * len(self.layers), self.layers)) + +        user_embeds = user_embeds * (1 / (len(self.layers) + 1)) +        item_embeds = item_embeds * (1 / (len(self.layers) + 1)) -        for k, layer in enumerate(self.layers): +        for k, (g, layer) in iterator:             h_dict = layer(g, h_dict) -            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1)) -            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1)) +            ue = h_dict[USER_KEY] +            ie = h_dict[ITEM_KEY] + +            if isinstance(in_g, list): +                ue = ue[in_g[-1].dstnodes(USER_KEY)] +                ie = ie[in_g[-1].dstnodes(ITEM_KEY)] + +            user_embeds = user_embeds + ue +            item_embeds = item_embeds + ie + +        user_embeds = user_embeds / (len(self.layers) + 1) +        item_embeds = item_embeds / (len(self.layers) + 1)          u_g_embeddings = user_embeds if users is None else user_embeds[users, :]         pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :] diff --git a/cornac/models/lightgcn/requirements.txt b/cornac/models/lightgcn/requirements.txt index 32f294fbcf7bff6f6fc5f7cfc8926fdbb78f4499..fac360e63836eb06d45ee71d4b3c37156764141f 100644 --- a/cornac/models/lightgcn/requirements.txt +++ b/cornac/models/lightgcn/requirements.txt @@ -1,2 +1,4 @@ -torch>=2.0.0 -dgl>=1.1.0 \ No newline at end of file +# Uncomment to use CUDA 11.x +#-f https://data.dgl.ai/wheels/cu11X/repo.html +torch==2.0.0 +dgl==1.1.0 \ No newline at end of file