Unverified commit a470d5c8 authored by darrylong, committed by GitHub

Optimize LightGCN Model (#531)


* Generated model base from LightGCN

* wip

* wip example

* add self-connection

* refactor code

* added sanity check

* Changed train batch size in example to 1024

* Updated readme for example folder

* Update Readme

* update docs

* Update block comment

* WIP

* Updated validation metric

* Updated message handling

* Added legacy lightgcn for comparison purposes

* Changed to follow 'a_k = 1/(k+1)', k instead of i (see the formula sketch after this list)

* Changed early stopping technique to follow NGCF

* Remove test_batch_size, set early stop verbose to false

* Changed parameters to align with paper and ngcf

* Refactor code

* update docstring

* change param name to 'batch_size'

* Fix paper reference
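
For reference, the a_k above are LightGCN's layer-combination weights: the final embedding is a weighted sum of the embeddings produced at each propagation layer, and this commit computes the weight from the layer index k rather than a loop index i. A sketch of the intended scheme (my reading of the commit message, with K layers):

    e_u = \sum_{k=0}^{K} a_k \, e_u^{(k)}, \quad a_k = \frac{1}{k+1}

so the input embedding and successive layer outputs contribute with weights 1, 1/2, 1/3, and so on.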

---------

Co-authored-by: tqtg <tuantq.vnu@gmail.com>
Co-authored-by: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
parent c4849883
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
USER_KEY = "user"
ITEM_KEY = "item"

def construct_graph(data_set):
    """
    Generates graph given a cornac data set

    Parameters
    ----------
    data_set:
        The data set as provided by cornac
    """
    user_indices, item_indices, _ = data_set.uir_tuple
    data_dict = {
        (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices),
        (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices),
    }
    num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items}
    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
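
As a quick illustration (a sketch, not part of the commit), the helper above produces a bipartite heterograph with one edge type per direction; the toy data here is made up:

import numpy as np
import dgl

# hypothetical interactions: user u interacted with item i
users = np.array([0, 0, 1, 2])
items = np.array([0, 1, 1, 0])

g = dgl.heterograph(
    {
        ("user", "user_item", "item"): (users, items),
        ("item", "item_user", "user"): (items, users),
    },
    num_nodes_dict={"user": 3, "item": 2},
)
print(g.num_nodes("user"), g.num_nodes("item"))      # 3 2
print(g.num_edges("user_item"), g.num_edges("item_user"))  # 4 4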

class GCNLayer(nn.Module):
    def __init__(self, norm_dict):
        super(GCNLayer, self).__init__()
        # symmetric normalization coefficients, precomputed per edge type
        self.norm_dict = norm_dict

    def forward(self, g, feat_dict):
        funcs = {}  # message and reduce functions dict
        # for each type of edges, compute messages and reduce them all
        for srctype, etype, dsttype in g.canonical_etypes:
            src, dst = g.edges(etype=(srctype, etype, dsttype))
            norm = self.norm_dict[(srctype, etype, dsttype)]
            messages = norm * feat_dict[srctype][src]  # compute messages
            g.edges[(srctype, etype, dsttype)].data[
                etype
            ] = messages  # store in edata
            funcs[(srctype, etype, dsttype)] = (
                fn.copy_e(etype, "m"),
                fn.sum("m", "h"),
            )  # define message and reduce functions

        # update all: reduce type-wise first, then across different edge types
        g.multi_update_all(funcs, "sum")

        feature_dict = {}
        for ntype in g.ntypes:
            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
            feature_dict[ntype] = h
        return feature_dict
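
The per-edge norm plus sum-reduce in this layer implements LightGCN's symmetrically normalized neighborhood aggregation; in math form (standard LightGCN propagation; the extra per-layer L2 normalization is specific to this implementation):

    e_u^{(k+1)} = \sum_{i \in N(u)} \frac{1}{\sqrt{|N(u)|} \sqrt{|N(i)|}} e_i^{(k)}

where N(u) is the set of items user u interacted with, and symmetrically for item embeddings.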

class Model(nn.Module):
    def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
        super(Model, self).__init__()
        self.norm_dict = dict()
        self.lambda_reg = lambda_reg
        self.device = device

        # precompute symmetric normalization 1/sqrt(d_src * d_dst) per edge
        for srctype, etype, dsttype in g.canonical_etypes:
            src, dst = g.edges(etype=(srctype, etype, dsttype))
            dst_degree = g.in_degrees(
                dst, etype=(srctype, etype, dsttype)
            ).float()  # obtain degrees
            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
            self.norm_dict[(srctype, etype, dsttype)] = norm

        self.layers = nn.ModuleList(
            [GCNLayer(self.norm_dict) for _ in range(num_layers)]
        )

        self.initializer = nn.init.xavier_uniform_

        # embeddings for different types of nodes
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(
                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
                )
                for ntype in g.ntypes
            }
        )

    def forward(self, g, users=None, pos_items=None, neg_items=None):
        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
        # obtain features of each layer and combine them with weights 1/(k+1)
        user_embeds = h_dict[USER_KEY]
        item_embeds = h_dict[ITEM_KEY]

        for k, layer in enumerate(self.layers):
            h_dict = layer(g, h_dict)
            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))

        u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
        pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
        neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :]

        return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings

    def loss_fn(self, users, pos_items, neg_items):
        pos_scores = (users * pos_items).sum(1)
        neg_scores = (users * neg_items).sum(1)

        bpr_loss = F.softplus(neg_scores - pos_scores).mean()  # BPR pairwise loss
        reg_loss = (
            (1 / 2)
            * (
                torch.norm(users) ** 2
                + torch.norm(pos_items) ** 2
                + torch.norm(neg_items) ** 2
            )
            / len(users)
        )

        return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss
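
loss_fn is BPR with L2 regularization: F.softplus(neg - pos) is algebraically identical to the classic -log sigmoid(pos - neg) form, just more numerically stable. A small self-contained check (a sketch, not from the commit):

import torch
import torch.nn.functional as F

pos = torch.tensor([2.0, 0.3])
neg = torch.tensor([0.5, 1.1])

lhs = F.softplus(neg - pos)                 # log(1 + exp(neg - pos))
rhs = -torch.log(torch.sigmoid(pos - neg))  # classic BPR formulation
print(torch.allclose(lhs, rhs))             # True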

# --- recommender wrapper (separate file in this diff) ---

class LightGCN(Recommender):
    """
    Parameters
    ----------
    name: string, default: 'LightGCN'
        The name of the recommender model.

    emb_size: int, default: 64
        Size of the node embeddings.

    num_epochs: int, default: 1000
        Maximum number of iterations or the number of epochs.

    learning_rate: float, default: 0.001
        The learning rate that determines the step size at each iteration.

    batch_size: int, default: 1024
        Mini-batch size used for the train set.

    num_layers: int, default: 3
        Number of LightGCN layers.
    ...
    """
    def __init__(
        self,
        name="LightGCN",
        emb_size=64,
        num_epochs=1000,
        learning_rate=0.001,
        batch_size=1024,
        num_layers=3,
        early_stopping=None,
        lambda_reg=1e-4,
        trainable=True,  # these two parameters are elided in the diff view;
        verbose=False,   # restored from the super().__init__ call below
        seed=2020,
    ):
        super().__init__(name=name, trainable=trainable, verbose=verbose)
        self.emb_size = emb_size
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.early_stopping = early_stopping
        self.lambda_reg = lambda_reg
        self.seed = seed

    def fit(self, train_set, val_set=None):  # signature elided in the diff view
        # ...
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)

        graph = construct_graph(train_set).to(self.device)
        model = Model(
            graph,
            self.emb_size,
            self.num_layers,
            self.lambda_reg,
        ).to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        # model training
        pbar = trange(
            self.num_epochs,
            # ... (remaining progress-bar arguments elided in the diff view)
        )
        for _ in pbar:  # epoch loop; opener elided in the diff view
            accum_loss = 0.0
            for batch_u, batch_i, batch_j in tqdm(
                train_set.uij_iter(
                    batch_size=self.batch_size,
                    shuffle=True,
                ),
                desc="Epoch",
                total=train_set.num_batches(self.batch_size),
                leave=False,
                position=1,
                disable=not self.verbose,
            ):
                batch_u = torch.from_numpy(batch_u).long().to(self.device)
                batch_i = torch.from_numpy(batch_i).long().to(self.device)
                batch_j = torch.from_numpy(batch_j).long().to(self.device)

                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
                    graph, batch_u, batch_i, batch_j
                )
                batch_loss, batch_bpr_loss, batch_reg_loss = model.loss_fn(
                    u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
                )
                accum_loss += batch_loss.cpu().item() * len(batch_u)

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            accum_loss /= len(train_set.uir_tuple[0])  # normalize over all observations

            # store user and item embedding matrices for prediction
            model.eval()
            u_embs, i_embs, _ = model(graph)
            # we will use numpy for faster prediction in the score function, no need torch
            self.U = u_embs.cpu().detach().numpy()
            self.V = i_embs.cpu().detach().numpy()

            if self.early_stopping is not None and self.early_stop(
                **self.early_stopping
            ):
                break

    def monitor_value(self):
        """Calculating monitored value used for early stopping on validation set (`val_set`).
        This function will be called by the `early_stop()` function.
        ...
        """
        if self.val_set is None:
            return None

        from ...metrics import Recall
        from ...eval_methods import ranking_eval

        recall_20 = ranking_eval(
            model=self,
            metrics=[Recall(k=20)],
            train_set=self.train_set,
            test_set=self.val_set,
        )[0][0]

        return recall_20  # Section 4.1.2 in the paper, same strategy as NGCF

    def score(self, user_idx, item_idx=None):
        """Predict the scores/ratings of a user for an item.
        ...
        """

# --- from the example script in the examples folder ---

# Instantiate the LightGCN model
lightgcn = cornac.models.LightGCN(
    seed=123,
    num_epochs=1000,
    num_layers=3,
    early_stopping={"min_delta": 1e-4, "patience": 50},
    batch_size=1024,
    learning_rate=0.001,
    lambda_reg=1e-4,
    verbose=True,
)
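
To run the updated example end to end, the model instance plugs into a cornac Experiment; a minimal sketch along those lines (dataset and metric choices here are illustrative, not taken from the diff):

import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.metrics import Recall, NDCG

feedback = movielens.load_feedback(variant="100K")
ratio_split = RatioSplit(data=feedback, test_size=0.2, seed=123, verbose=True)

cornac.Experiment(
    eval_method=ratio_split,
    models=[lightgcn],  # the instance created above
    metrics=[Recall(k=20), NDCG(k=20)],
).run()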