Skip to content
Snippets Groups Projects

final push

Merged Mohamed Yacine Touahria Miliani requested to merge master into main
160 files
+ 9677
73
Compare changes
  • Side-by-side
  • Inline
Files
160
+ 56
0
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from Ghypeddings.H2HGCN.utils import *
class CentroidDistance(nn.Module):
"""
Implement a model that calculates the pairwise distances between node representations
and centroids
"""
def __init__(self, args, logger, manifold):
super(CentroidDistance, self).__init__()
self.args = args
self.logger = logger
self.manifold = manifold
self.debug = False
# centroid embedding
self.centroid_embedding = nn.Embedding(
args.num_centroid, args.dim,
sparse=False,
scale_grad_by_freq=False,
)
nn_init(self.centroid_embedding, self.args.proj_init)
args.eucl_vars.append(self.centroid_embedding)
def forward(self, node_repr, mask):
"""
Args:
node_repr: [node_num, dim]
mask: [node_num, 1] 1 denote real node, 0 padded node
return:
graph_centroid_dist: [1, num_centroid]
node_centroid_dist: [1, node_num, num_centroid]
"""
node_num = node_repr.size(0)
# broadcast and reshape node_repr to [node_num * num_centroid, dim]
node_repr = node_repr.unsqueeze(1).expand(
-1,
self.args.num_centroid,
-1).contiguous().view(-1, self.args.dim)
# broadcast and reshape centroid embeddings to [node_num * num_centroid, dim]
centroid_repr = self.manifold.exp_map_zero(self.centroid_embedding(th.arange(self.args.num_centroid).cuda().to(self.args.device)))
centroid_repr = centroid_repr.unsqueeze(0).expand(
node_num,
-1,
-1).contiguous().view(-1, self.args.dim)
# get distance
node_centroid_dist = self.manifold.distance(node_repr, centroid_repr)
node_centroid_dist = node_centroid_dist.view(1, node_num, self.args.num_centroid)
# average pooling over nodes
graph_centroid_dist = th.sum(node_centroid_dist, dim=1) / th.sum(mask)
return graph_centroid_dist, node_centroid_dist
Loading