diff --git a/.gitignore b/.gitignore
index 270230400f0e246644db8e05dcc6de126c2d96d6..01271755282894ddfe89bd70350a4da9587fbea7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,3 +55,10 @@ coverage.xml
 
 # Sphinx documentation
 docs/_build/
+
+
+# Environment
+env
+venv
+.env
+.venv
diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py
index 6fcdb7f5c42041f1e8c9e962da1173a6bb9bf359..c2f0dba92e429f13c72e43700d3a1e3bd9de7fe4 100644
--- a/cornac/models/lightgcn/lightgcn.py
+++ b/cornac/models/lightgcn/lightgcn.py
@@ -1,10 +1,11 @@
+from typing import Union, List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 
-
 USER_KEY = "user"
 ITEM_KEY = "item"
 
@@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items):
     }
     num_dict = {USER_KEY: total_users, ITEM_KEY: total_items}
 
-    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    norm_dict = {}
+    for srctype, etype, dsttype in g.canonical_etypes:
+        src, dst = g.edges(etype=(srctype, etype, dsttype))
+        dst_degree = g.in_degrees(
+            dst, etype=(srctype, etype, dsttype)
+        ).float()  # obtain degrees
+        src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+        norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+        g.edata['norm'] = {etype: norm}
+
+    return g
 
 
 class GCNLayer(nn.Module):
-    def __init__(self, norm_dict):
+    def __init__(self):
         super(GCNLayer, self).__init__()
 
-        # norm
-        self.norm_dict = norm_dict
-
     def forward(self, g, feat_dict):
         funcs = {}  # message and reduce functions dict
         # for each type of edges, compute messages and reduce them all
+        g.ndata["h"] = feat_dict
         for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            norm = self.norm_dict[(srctype, etype, dsttype)]
-            # TODO: CHECK HERE
-            messages = norm * feat_dict[srctype][src]  # compute messages
-            g.edges[(srctype, etype, dsttype)].data[
-                etype
-            ] = messages  # store in edata
             funcs[(srctype, etype, dsttype)] = (
-                fn.copy_e(etype, "m"),
-                fn.sum("m", "h"),
+                fn.u_mul_e("h", "norm", "m"),
+                fn.sum("m", "h_n"),
             )  # define message and reduce functions
 
         g.multi_update_all(
             funcs, "sum"
         )  # update all, reduce by first type-wisely then across different types
-        feature_dict = {}
-        for ntype in g.ntypes:
-            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
-            feature_dict[ntype] = h
-        return feature_dict
+        return g.dstdata["h_n"]
 
 
 class Model(nn.Module):
@@ -69,16 +68,7 @@ class Model(nn.Module):
         self.lambda_reg = lambda_reg
         self.device = device
 
-        for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            dst_degree = g.in_degrees(
-                dst, etype=(srctype, etype, dsttype)
-            ).float()  # obtain degrees
-            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
-            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
-            self.norm_dict[(srctype, etype, dsttype)] = norm
-
-        self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
 
         self.initializer = nn.init.xavier_uniform_
 
@@ -92,16 +82,37 @@ class Model(nn.Module):
             }
         )
 
-    def forward(self, g, users=None, pos_items=None, neg_items=None):
-        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
-        # obtain features of each layer and concatenate them all
-        user_embeds = h_dict[USER_KEY]
-        item_embeds = h_dict[ITEM_KEY]
+    def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None):
+
+        if isinstance(in_g, list):
+            h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes}
+            user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)]
+            item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)]
+            iterator = enumerate(zip(in_g, self.layers))
+        else:
+            h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes}
+            # obtain features of each layer and concatenate them all
+            user_embeds = h_dict[USER_KEY]
+            item_embeds = h_dict[ITEM_KEY]
+            iterator = enumerate(zip([in_g] * len(self.layers), self.layers))
+
+        user_embeds = user_embeds * (1 / (len(self.layers) + 1))
+        item_embeds = item_embeds * (1 / (len(self.layers) + 1))
 
-        for k, layer in enumerate(self.layers):
+        for k, (g, layer) in iterator:
             h_dict = layer(g, h_dict)
-            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
-            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+            ue = h_dict[USER_KEY]
+            ie = h_dict[ITEM_KEY]
+
+            if isinstance(in_g, list):
+                ue = ue[in_g[-1].dstnodes(USER_KEY)]
+                ie = ie[in_g[-1].dstnodes(ITEM_KEY)]
+                
+            user_embeds = user_embeds + ue
+            item_embeds = item_embeds + ie
+
+        user_embeds = user_embeds / (len(self.layers) + 1)
+        item_embeds = item_embeds / (len(self.layers) + 1)
 
         u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
         pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
diff --git a/cornac/models/lightgcn/requirements.txt b/cornac/models/lightgcn/requirements.txt
index 32f294fbcf7bff6f6fc5f7cfc8926fdbb78f4499..fac360e63836eb06d45ee71d4b3c37156764141f 100644
--- a/cornac/models/lightgcn/requirements.txt
+++ b/cornac/models/lightgcn/requirements.txt
@@ -1,2 +1,4 @@
-torch>=2.0.0
-dgl>=1.1.0
\ No newline at end of file
+# Comment in to use cuda 11.X
+#-f https://data.dgl.ai/wheels/cu11X/repo.html
+torch==2.0.0
+dgl==1.1.0
\ No newline at end of file