From 2b4f2b481b1a5f94e6df540de7678a7184bf035d Mon Sep 17 00:00:00 2001
From: Athmane Mansour Bahar <ja_mansourbahar@esi.dz>
Date: Thu, 15 Aug 2024 17:46:52 +0000
Subject: [PATCH] Upload New File

---
 utils/poolers.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 utils/poolers.py

diff --git a/utils/poolers.py b/utils/poolers.py
new file mode 100644
index 0000000..4b9dc4a
--- /dev/null
+++ b/utils/poolers.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+
+class Pooling(nn.Module):
+    def __init__(self, pooler):
+        super(Pooling, self).__init__()
+        self.pooler = pooler
+
+    def forward(self, graph, feat, t=None):
+        # Node-type-specific readout: `t` selects which node types to pool over,
+        # either None (all nodes), an int (one type), or a sequence of types.
+        with graph.local_scope():
+            if t is None:
+                if self.pooler == 'mean':
+                    return feat.mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat.sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat.max(0, keepdim=True).values  # .max returns (values, indices)
+                else:
+                    raise NotImplementedError(f"Unsupported pooler: {self.pooler}")
+            elif isinstance(t, int):
+                mask = (graph.ndata['type'] == t)
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True).values
+                else:
+                    raise NotImplementedError(f"Unsupported pooler: {self.pooler}")
+            else:
+                mask = (graph.ndata['type'] == t[0])
+                for i in range(1, len(t)):
+                    mask |= (graph.ndata['type'] == t[i])
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True).values
+                else:
+                    raise NotImplementedError(f"Unsupported pooler: {self.pooler}")
-- 
GitLab
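
Usage sketch (not part of the applied patch): a minimal, illustrative way to
drive the Pooling module, assuming a DGL graph whose ndata carries an integer
'type' field as forward() expects. The toy graph, feature size, and type ids
below are invented for illustration.

    import dgl
    import torch

    from utils.poolers import Pooling

    # Toy graph: 4 nodes with types 0, 0, 1, 2 and 8-dim node features.
    g = dgl.graph(([0, 1, 2], [1, 2, 3]), num_nodes=4)
    g.ndata['type'] = torch.tensor([0, 0, 1, 2])
    feat = torch.randn(4, 8)

    pool = Pooling('mean')
    h_all = pool(g, feat)             # (1, 8): mean over all nodes
    h_one = pool(g, feat, t=0)        # (1, 8): mean over nodes of type 0
    h_set = pool(g, feat, t=[0, 2])   # (1, 8): mean over nodes of types 0 and 2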