Unverified Commit a2ee37e3 authored by Theis Jendal, committed by GitHub

Lightgcn fix (#602)

* Add .gitignore entries

* Lightgcn fix

Removed the per-layer L2 normalization, which is not used in LightGCN.
Fixed the layer-sum weight to use a constant based on the number of layers instead of the current layer index (see the sketch below).
Allow LightGCN to take DGL blocks.
Fixed a requirements error caused by newer DGL versions.
Moved the edge normalization onto the graph for easier use.

* Lightgcn debug error fix

* Simplified layer normalization and improved readability

* Easier support of rcuda
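
For context on the layer-sum fix: LightGCN forms the final embedding as an equally weighted average of the ego embedding and every propagation layer's output, so each term is scaled by the constant 1 / (K + 1), where K is the number of layers. Below is a minimal sketch of that combination (hypothetical tensors and sizes, not the repository's code):

import torch

# Hypothetical per-layer embeddings e_0 ... e_K for 5 nodes with dimension 8;
# layer 0 is the ego (initial) embedding, the rest are the GCN layer outputs.
num_layers = 3  # K
layer_embeds = [torch.randn(5, 8) for _ in range(num_layers + 1)]

# LightGCN combines layers with the constant weight 1 / (K + 1), i.e. the
# final embedding is the mean over all layer outputs.
final_embeds = sum(layer_embeds) / (num_layers + 1)

# The previous code instead weighted the output of layer k by 1 / (k + 1),
# which over-weights early layers relative to this formulation.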
parent c6035980
@@ -55,3 +55,10 @@ coverage.xml
 # Sphinx documentation
 docs/_build/
+# Environment
+env
+venv
+.env
+.venv
+from typing import Union, List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 
 USER_KEY = "user"
 ITEM_KEY = "item"
@@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items):
     }
     num_dict = {USER_KEY: total_users, ITEM_KEY: total_items}
-    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    norm_dict = {}
+    for srctype, etype, dsttype in g.canonical_etypes:
+        src, dst = g.edges(etype=(srctype, etype, dsttype))
+        dst_degree = g.in_degrees(
+            dst, etype=(srctype, etype, dsttype)
+        ).float()  # obtain degrees
+        src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+        norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+        g.edata['norm'] = {etype: norm}
+    return g
 class GCNLayer(nn.Module):
-    def __init__(self, norm_dict):
+    def __init__(self):
         super(GCNLayer, self).__init__()
-        # norm
-        self.norm_dict = norm_dict
 
     def forward(self, g, feat_dict):
         funcs = {}  # message and reduce functions dict
         # for each type of edges, compute messages and reduce them all
+        g.ndata["h"] = feat_dict
         for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            norm = self.norm_dict[(srctype, etype, dsttype)]
-            # TODO: CHECK HERE
-            messages = norm * feat_dict[srctype][src]  # compute messages
-            g.edges[(srctype, etype, dsttype)].data[
-                etype
-            ] = messages  # store in edata
             funcs[(srctype, etype, dsttype)] = (
-                fn.copy_e(etype, "m"),
-                fn.sum("m", "h"),
+                fn.u_mul_e("h", "norm", "m"),
+                fn.sum("m", "h_n"),
             )  # define message and reduce functions
         g.multi_update_all(
             funcs, "sum"
         )  # update all, reduce by first type-wisely then across different types
-        feature_dict = {}
-        for ntype in g.ntypes:
-            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
-            feature_dict[ntype] = h
-        return feature_dict
+        return g.dstdata["h_n"]
 
 class Model(nn.Module):
@@ -69,16 +68,7 @@ class Model(nn.Module):
         self.lambda_reg = lambda_reg
         self.device = device
 
-        for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            dst_degree = g.in_degrees(
-                dst, etype=(srctype, etype, dsttype)
-            ).float()  # obtain degrees
-            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
-            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
-            self.norm_dict[(srctype, etype, dsttype)] = norm
-        self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
 
         self.initializer = nn.init.xavier_uniform_
@@ -92,16 +82,37 @@ class Model(nn.Module):
             }
         )
 
-    def forward(self, g, users=None, pos_items=None, neg_items=None):
-        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
-        # obtain features of each layer and concatenate them all
-        user_embeds = h_dict[USER_KEY]
-        item_embeds = h_dict[ITEM_KEY]
+    def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None):
+        if isinstance(in_g, list):
+            h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes}
+            user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)]
+            item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)]
+            iterator = enumerate(zip(in_g, self.layers))
+        else:
+            h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes}
+            # obtain features of each layer and concatenate them all
+            user_embeds = h_dict[USER_KEY]
+            item_embeds = h_dict[ITEM_KEY]
+            iterator = enumerate(zip([in_g] * len(self.layers), self.layers))
+
+        user_embeds = user_embeds * (1 / (len(self.layers) + 1))
+        item_embeds = item_embeds * (1 / (len(self.layers) + 1))
 
-        for k, layer in enumerate(self.layers):
+        for k, (g, layer) in iterator:
             h_dict = layer(g, h_dict)
-            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
-            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+            ue = h_dict[USER_KEY]
+            ie = h_dict[ITEM_KEY]
+            if isinstance(in_g, list):
+                ue = ue[in_g[-1].dstnodes(USER_KEY)]
+                ie = ie[in_g[-1].dstnodes(ITEM_KEY)]
+            user_embeds = user_embeds + ue
+            item_embeds = item_embeds + ie
+
+        user_embeds = user_embeds / (len(self.layers) + 1)
+        item_embeds = item_embeds / (len(self.layers) + 1)
 
         u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
         pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
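
To make the "moved edge normalization to graph" change concrete, here is a self-contained sketch (the toy interaction data, relation names, and sizes are assumptions, not the repository's dataset) that stores the symmetric 1/sqrt(d_src * d_dst) weight on the edges of a small user-item heterograph and runs one u_mul_e / sum propagation step in the style of the new GCNLayer:

import torch
import dgl
import dgl.function as fn

# Toy bipartite interactions: 3 users, 4 items (hypothetical data).
users = torch.tensor([0, 0, 1, 2])
items = torch.tensor([0, 1, 2, 3])
g = dgl.heterograph(
    {
        ("user", "rates", "item"): (users, items),
        ("item", "rated_by", "user"): (items, users),
    },
    num_nodes_dict={"user": 3, "item": 4},
)

# Store the symmetric normalization 1 / sqrt(d_src * d_dst) on every edge,
# mirroring what the new construct_graph does.
for srctype, etype, dsttype in g.canonical_etypes:
    src, dst = g.edges(etype=(srctype, etype, dsttype))
    dst_degree = g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()
    src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
    norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)
    g.edata["norm"] = {(srctype, etype, dsttype): norm}

# One propagation step in the style of the new GCNLayer.forward: multiply the
# source feature by the edge norm, then sum the messages at the destination.
feat_dict = {"user": torch.randn(3, 8), "item": torch.randn(4, 8)}
g.ndata["h"] = feat_dict
funcs = {
    cet: (fn.u_mul_e("h", "norm", "m"), fn.sum("m", "h_n"))
    for cet in g.canonical_etypes
}
g.multi_update_all(funcs, "sum")
print({ntype: feat.shape for ntype, feat in g.dstdata["h_n"].items()})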
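
For the new block-based path ("Allow lgcn to take blocks"): the updated forward accepts either the full heterograph or a list with one DGL block (message flow graph) per layer. A hedged sketch of producing such a list with DGL's neighbor-sampling data loader (the toy graph, seed dict, and batch size are assumptions, not the repository's training loop):

import torch
import dgl

# Toy user-item heterograph standing in for the output of construct_graph.
g = dgl.heterograph(
    {
        ("user", "rates", "item"): (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])),
        ("item", "rated_by", "user"): (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])),
    },
    num_nodes_dict={"user": 3, "item": 3},
)

num_layers = 2  # must match the number of GCNLayer modules: one block per layer
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(num_layers)
seeds = {"user": torch.arange(g.num_nodes("user")),
         "item": torch.arange(g.num_nodes("item"))}
dataloader = dgl.dataloading.DataLoader(g, seeds, sampler, batch_size=2, shuffle=True)

for input_nodes, output_nodes, blocks in dataloader:
    # `blocks` is a list of message flow graphs, outermost layer first;
    # passing it to Model.forward takes the isinstance(in_g, list) branch.
    print(len(blocks), [blk.num_dst_nodes() for blk in blocks])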
-torch>=2.0.0
-dgl>=1.1.0
\ No newline at end of file
+# Comment in to use cuda 11.X
+#-f https://data.dgl.ai/wheels/cu11X/repo.html
+torch==2.0.0
+dgl==1.1.0
\ No newline at end of file