From a2ee37e3919564a38d308507b564de5c5856782d Mon Sep 17 00:00:00 2001 From: Theis Jendal <jendal@live.dk> Date: Mon, 18 Mar 2024 10:50:16 +0100 Subject: [PATCH] Lightgcn fix (#602) * Add git ignore * Lightgcn fix Removed normalization for layers, not used for lgcn. Fixed sum weight constant to num layers instead of cur layer index. Allow lgcn to take blocks. Fixed requirement error caused by newer dgl versions. Moved edge normalization to graph for easier use. * Lightgcn debug error fix * Simplified layer normalization and readability * Easier support of rcuda --- .gitignore | 7 ++ cornac/models/lightgcn/lightgcn.py | 87 ++++++++++++++----------- cornac/models/lightgcn/requirements.txt | 6 +- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index 27023040..01271755 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,10 @@ coverage.xml # Sphinx documentation docs/_build/ + + +# Environment +env +venv +.env +.venv diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py index 6fcdb7f5..c2f0dba9 100644 --- a/cornac/models/lightgcn/lightgcn.py +++ b/cornac/models/lightgcn/lightgcn.py @@ -1,10 +1,11 @@ +from typing import Union, List + import torch import torch.nn as nn import torch.nn.functional as F import dgl import dgl.function as fn - USER_KEY = "user" ITEM_KEY = "item" @@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items): } num_dict = {USER_KEY: total_users, ITEM_KEY: total_items} - return dgl.heterograph(data_dict, num_nodes_dict=num_dict) + g = dgl.heterograph(data_dict, num_nodes_dict=num_dict) + norm_dict = {} + for srctype, etype, dsttype in g.canonical_etypes: + src, dst = g.edges(etype=(srctype, etype, dsttype)) + dst_degree = g.in_degrees( + dst, etype=(srctype, etype, dsttype) + ).float() # obtain degrees + src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() + norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm + g.edata['norm'] = {etype: norm} + + return g class GCNLayer(nn.Module): - def __init__(self, norm_dict): + def __init__(self): super(GCNLayer, self).__init__() - # norm - self.norm_dict = norm_dict - def forward(self, g, feat_dict): funcs = {} # message and reduce functions dict # for each type of edges, compute messages and reduce them all + g.ndata["h"] = feat_dict for srctype, etype, dsttype in g.canonical_etypes: - src, dst = g.edges(etype=(srctype, etype, dsttype)) - norm = self.norm_dict[(srctype, etype, dsttype)] - # TODO: CHECK HERE - messages = norm * feat_dict[srctype][src] # compute messages - g.edges[(srctype, etype, dsttype)].data[ - etype - ] = messages # store in edata funcs[(srctype, etype, dsttype)] = ( - fn.copy_e(etype, "m"), - fn.sum("m", "h"), + fn.u_mul_e("h", "norm", "m"), + fn.sum("m", "h_n"), ) # define message and reduce functions g.multi_update_all( funcs, "sum" ) # update all, reduce by first type-wisely then across different types - feature_dict = {} - for ntype in g.ntypes: - h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2) # l2 normalize - feature_dict[ntype] = h - return feature_dict + return g.dstdata["h_n"] class Model(nn.Module): @@ -69,16 +68,7 @@ class Model(nn.Module): self.lambda_reg = lambda_reg self.device = device - for srctype, etype, dsttype in g.canonical_etypes: - src, dst = g.edges(etype=(srctype, etype, dsttype)) - dst_degree = g.in_degrees( - dst, etype=(srctype, etype, dsttype) - ).float() # obtain degrees - src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() - norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm - self.norm_dict[(srctype, etype, dsttype)] = norm - - self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)]) + self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)]) self.initializer = nn.init.xavier_uniform_ @@ -92,16 +82,37 @@ class Model(nn.Module): } ) - def forward(self, g, users=None, pos_items=None, neg_items=None): - h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes} - # obtain features of each layer and concatenate them all - user_embeds = h_dict[USER_KEY] - item_embeds = h_dict[ITEM_KEY] + def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None): + + if isinstance(in_g, list): + h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes} + user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)] + item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)] + iterator = enumerate(zip(in_g, self.layers)) + else: + h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes} + # obtain features of each layer and concatenate them all + user_embeds = h_dict[USER_KEY] + item_embeds = h_dict[ITEM_KEY] + iterator = enumerate(zip([in_g] * len(self.layers), self.layers)) + + user_embeds = user_embeds * (1 / (len(self.layers) + 1)) + item_embeds = item_embeds * (1 / (len(self.layers) + 1)) - for k, layer in enumerate(self.layers): + for k, (g, layer) in iterator: h_dict = layer(g, h_dict) - user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1)) - item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1)) + ue = h_dict[USER_KEY] + ie = h_dict[ITEM_KEY] + + if isinstance(in_g, list): + ue = ue[in_g[-1].dstnodes(USER_KEY)] + ie = ie[in_g[-1].dstnodes(ITEM_KEY)] + + user_embeds = user_embeds + ue + item_embeds = item_embeds + ie + + user_embeds = user_embeds / (len(self.layers) + 1) + item_embeds = item_embeds / (len(self.layers) + 1) u_g_embeddings = user_embeds if users is None else user_embeds[users, :] pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :] diff --git a/cornac/models/lightgcn/requirements.txt b/cornac/models/lightgcn/requirements.txt index 32f294fb..fac360e6 100644 --- a/cornac/models/lightgcn/requirements.txt +++ b/cornac/models/lightgcn/requirements.txt @@ -1,2 +1,4 @@ -torch>=2.0.0 -dgl>=1.1.0 \ No newline at end of file +# Comment in to use cuda 11.X +#-f https://data.dgl.ai/wheels/cu11X/repo.html +torch==2.0.0 +dgl==1.1.0 \ No newline at end of file -- GitLab