From 9f8e0a984ccee9a2ad300c8fc6f221346e22c047 Mon Sep 17 00:00:00 2001 From: Athmane Mansour Bahar <ja_mansourbahar@esi.dz> Date: Thu, 15 Aug 2024 17:50:40 +0000 Subject: [PATCH] Upload New File --- trainer/utils/poolers.py | 43 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 trainer/utils/poolers.py diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py new file mode 100644 index 0000000..4b9dc4a --- /dev/null +++ b/trainer/utils/poolers.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +class Pooling(nn.Module): + def __init__(self, pooler): + super(Pooling, self).__init__() + self.pooler = pooler + + def forward(self, graph, feat, t=None): + feat = feat + # Implement node type-specific pooling + with graph.local_scope(): + if t is None: + if self.pooler == 'mean': + return feat.mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat.sum(0, keepdim=True) + elif self.pooler == 'max': + return feat.max(0, keepdim=True) + else: + raise NotImplementedError + elif isinstance(t, int): + mask = (graph.ndata['type'] == t) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError + else: + mask = (graph.ndata['type'] == t[0]) + for i in range(1, len(t)): + mask |= (graph.ndata['type'] == t[i]) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError -- GitLab