From 2b4f2b481b1a5f94e6df540de7678a7184bf035d Mon Sep 17 00:00:00 2001 From: Athmane Mansour Bahar <ja_mansourbahar@esi.dz> Date: Thu, 15 Aug 2024 17:46:52 +0000 Subject: [PATCH] Upload New File --- utils/poolers.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 utils/poolers.py diff --git a/utils/poolers.py b/utils/poolers.py new file mode 100644 index 0000000..4b9dc4a --- /dev/null +++ b/utils/poolers.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +class Pooling(nn.Module): + def __init__(self, pooler): + super(Pooling, self).__init__() + self.pooler = pooler + + def forward(self, graph, feat, t=None): + feat = feat + # Implement node type-specific pooling + with graph.local_scope(): + if t is None: + if self.pooler == 'mean': + return feat.mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat.sum(0, keepdim=True) + elif self.pooler == 'max': + return feat.max(0, keepdim=True) + else: + raise NotImplementedError + elif isinstance(t, int): + mask = (graph.ndata['type'] == t) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError + else: + mask = (graph.ndata['type'] == t[0]) + for i in range(1, len(t)): + mask |= (graph.ndata['type'] == t[i]) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError -- GitLab