from collections import defaultdict
import os
import pickle
import json
import torch.nn as nn
import torch as th
import torch.optim as optim
import numpy as np
import random
from Ghypeddings.H2HGCN.optimizers.rsgd import RiemannianSGD
import math
import subprocess
def set_seed(seed):
"""
Set the random seed
"""
random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
th.cuda.manual_seed(seed)
th.cuda.manual_seed_all(seed)
def th_dot(x, y, keepdim=True):
return th.sum(x * y, dim=1, keepdim=keepdim)
def pad_sequence(data_list, maxlen, value=0):
return [row + [value] * (maxlen - len(row)) for row in data_list]
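# Illustrative usage of the two helpers above (a sketch, not part of the module):
# >>> pad_sequence([[1, 2], [3]], maxlen=4)
# [[1, 2, 0, 0], [3, 0, 0, 0]]
# >>> th_dot(th.ones(2, 3), th.ones(2, 3))  # row-wise dot products
# tensor([[3.],
#         [3.]])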
def normalize_weight(adj_mat, weight):
degree = [1 / math.sqrt(sum(np.abs(w))) for w in weight]
for dst in range(len(adj_mat)):
for src_idx in range(len(adj_mat[dst])):
src = adj_mat[dst][src_idx]
weight[dst][src_idx] = degree[dst] * weight[dst][src_idx] * degree[src]
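# normalize_weight rescales the weights in place with the symmetric GCN
# normalization w_ij <- w_ij / sqrt(d_i * d_j), where d_i = sum_j |w_ij|.
# Minimal sketch on a hypothetical two-node graph (each node's adjacency list
# holds the other node, both edges weighted 2, so d_0 = d_1 = 2):
# adj_mat, weight = [[1], [0]], [[2.0], [2.0]]
# normalize_weight(adj_mat, weight)
# weight  # -> [[1.0], [1.0]], i.e. 2 / (sqrt(2) * sqrt(2))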
def nn_init(nn_module, method='orthogonal'):
"""
Initialize a Sequential or Module object
Args:
nn_module: Sequential or Module
method: initialization method
"""
if method == 'none':
return
for param_name, _ in nn_module.named_parameters():
if isinstance(nn_module, nn.Sequential):
# for a Sequential object, the param_name contains both id and param name
i, name = param_name.split('.', 1)
param = getattr(nn_module[int(i)], name)
else:
param = getattr(nn_module, param_name)
if param_name.find('weight') > -1:
init_weight(param, method)
elif param_name.find('bias') > -1:
nn.init.uniform_(param, -1e-4, 1e-4)
def get_params(params_list, vars_list):
"""
Add parameters in vars_list to param_list
"""
for i in vars_list:
if issubclass(i.__class__, nn.Module):
params_list.extend(list(i.parameters()))
elif issubclass(i.__class__, nn.Parameter):
params_list.append(i)
else:
print("Encounter unknown objects")
exit(1)
def categorize_params(args):
"""
Categorize parameters into hyperbolic ones and euclidean ones
"""
stiefel_params, euclidean_params = [], []
get_params(euclidean_params, args.eucl_vars)
get_params(stiefel_params, args.stie_vars)
return stiefel_params, euclidean_params
def get_activation(args):
if args.activation == 'leaky_relu':
return nn.LeakyReLU(args.leaky_relu)
elif args.activation == 'rrelu':
return nn.RReLU()
elif args.activation == 'relu':
return nn.ReLU()
elif args.activation == 'elu':
return nn.ELU()
elif args.activation == 'prelu':
return nn.PReLU()
    elif args.activation == 'selu':
        return nn.SELU()
    else:
        raise ValueError('unknown activation: {}'.format(args.activation))
def init_weight(weight, method):
"""
Initialize parameters
Args:
weight: a Parameter object
method: initialization method
"""
if method == 'orthogonal':
nn.init.orthogonal_(weight)
elif method == 'xavier':
nn.init.xavier_uniform_(weight)
elif method == 'kaiming':
nn.init.kaiming_uniform_(weight)
elif method == 'none':
pass
else:
raise Exception('Unknown init method')
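# Illustrative usage (a sketch): initialize a single layer with orthogonal
# weights and near-zero biases.
# layer = nn.Linear(8, 4)
# nn_init(layer, method='orthogonal')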
def get_stiefel_optimizer(args, params, lr_stie):
if args.stiefel_optimizer == 'rsgd':
optimizer = RiemannianSGD(
args,
params,
lr=lr_stie,
)
    elif args.stiefel_optimizer == 'ramsgrad':
        # NOTE: RiemannianAMSGrad is not imported at the top of this module;
        # selecting 'ramsgrad' will raise a NameError unless it is imported.
        optimizer = RiemannianAMSGrad(
args,
params,
lr=lr_stie,
)
else:
print("unsupported hyper optimizer")
exit(1)
return optimizer
class NoneScheduler:
    """No-op scheduler returned when args.lr_scheduler == 'none'."""
    def step(self):
        pass

def get_lr_scheduler(args, optimizer):
if args.lr_scheduler == 'exponential':
return optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
elif args.lr_scheduler == 'cosine':
return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=0)
elif args.lr_scheduler == 'cycle':
return optim.lr_scheduler.CyclicLR(optimizer, 0, max_lr=args.lr, step_size_up=20, cycle_momentum=False)
elif args.lr_scheduler == 'step':
return optim.lr_scheduler.StepLR(
optimizer,
step_size=int(args.step_lr_reduce_freq),
gamma=float(args.step_lr_gamma)
)
elif args.lr_scheduler == 'none':
return NoneScheduler()
def get_optimizer(args, params, lr):
if args.optimizer == 'sgd':
optimizer = optim.SGD(params, lr=lr, weight_decay=args.weight_decay)
elif args.optimizer == 'Adam':
optimizer = optim.Adam(params, lr=lr, weight_decay=args.weight_decay)
elif args.optimizer == 'amsgrad':
        optimizer = optim.Adam(params, lr=lr, amsgrad=True, weight_decay=args.weight_decay)
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))
    return optimizer
def set_up_optimizer_scheduler(hyperbolic, args, model, lr, lr_stie, pprint=True):
stiefel_params, euclidean_params = categorize_params(args)
#assert(len(list(model.parameters())) == len(stiefel_params) + len(euclidean_params))
optimizer = get_optimizer(args, euclidean_params, lr)
lr_scheduler = get_lr_scheduler(args, optimizer)
if len(stiefel_params) > 0:
stiefel_optimizer = get_stiefel_optimizer(args, stiefel_params, lr_stie)
stiefel_lr_scheduler = get_lr_scheduler(args, stiefel_optimizer)
else:
stiefel_optimizer, stiefel_lr_scheduler = None, None
return optimizer, lr_scheduler, stiefel_optimizer, stiefel_lr_scheduler
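# Illustrative usage (a sketch; assumes `args` carries the fields referenced
# above and that the model has filled args.eucl_vars / args.stie_vars):
# optimizer, lr_sched, stie_opt, stie_sched = set_up_optimizer_scheduler(
#     True, args, model, args.lr, args.lr_stie)
# loss.backward()
# optimizer.step()
# if stie_opt is not None:
#     stie_opt.step()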
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.modules.loss
import argparse
def format_metrics(metrics, split):
"""Format metric in metric dict for logging."""
return " ".join(
["{}_{}: {:.4f}".format(split, metric_name, metric_val) for metric_name, metric_val in metrics.items()])
def create_args(*args):
parser = argparse.ArgumentParser()
parser.add_argument('--dim', type=int, default=args[0])
parser.add_argument('--c', type=int, default=args[1])
parser.add_argument('--num_layers', type=int, default=args[2])
parser.add_argument('--bias', type=bool, default=args[3])
parser.add_argument('--act', type=str, default=args[4])
parser.add_argument('--select_manifold', type=str, default=args[5])
parser.add_argument('--num_centroid', type=int, default=args[6])
parser.add_argument('--lr_stie', type=float, default=args[7])
parser.add_argument('--stie_vars', nargs='+', default=args[8])
parser.add_argument('--stiefel_optimizer', type=str, default=args[9])
parser.add_argument('--eucl_vars', nargs='+', default=args[10])
parser.add_argument('--grad_clip', type=float, default=args[11])
parser.add_argument('--optimizer', type=str, default=args[12])
parser.add_argument('--weight_decay', type=float, default=args[13])
parser.add_argument('--lr', type=float, default=args[14])
parser.add_argument('--lr_scheduler', type=str, default=args[15])
parser.add_argument('--lr_gamma', type=float, default=args[16])
parser.add_argument('--step_lr_gamma', type=float, default=args[17])
parser.add_argument('--step_lr_reduce_freq', type=int, default=args[18])
parser.add_argument('--proj_init', type=str, default=args[19])
parser.add_argument('--tie_weight', type=bool, default=args[20])
parser.add_argument('--cuda', type=int, default=args[21])
parser.add_argument('--epochs', type=int, default=args[22])
parser.add_argument('--min_epochs', type=int, default=args[23])
parser.add_argument('--patience', type=int, default=args[24])
parser.add_argument('--seed', type=int, default=args[25])
parser.add_argument('--log_freq', type=int, default=args[26])
parser.add_argument('--eval_freq', type=int, default=args[27])
parser.add_argument('--val_prop', type=float, default=args[28])
parser.add_argument('--test_prop', type=float, default=args[29])
parser.add_argument('--double_precision', type=int, default=args[30])
parser.add_argument('--dropout', type=float, default=args[31])
parser.add_argument('--normalize_adj', type=bool, default=args[32])
parser.add_argument('--normalize_feats', type=bool, default=args[33])
flags, unknown = parser.parse_known_args()
return flags
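# NOTE: argparse's `type=bool` does not parse strings as one might expect
# (bool('False') is True); it is harmless here only because every flag is
# filled from the positional defaults above rather than from the command line.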
from __future__ import print_function
from __future__ import division
from Ghypeddings.HGCAE.models.base_models import LPModel
import logging
import torch
import numpy as np
import os
import time
from Ghypeddings.HGCAE.utils.train_utils import get_dir_name, format_metrics
from Ghypeddings.HGCAE.utils.data_utils import process_data
from Ghypeddings.HGCAE.utils.train_utils import create_args , get_classifier ,get_clustering_algorithm,get_anomaly_detection_algorithm
import Ghypeddings.HGCAE.optimizers as optimizers
from Ghypeddings.HGCAE.utils.data_utils import sparse_mx_to_torch_sparse_tensor
from Ghypeddings.classifiers import calculate_metrics
class HGCAE(object):
def __init__(self,
adj,
features,
labels,
dim,
hidden_dim,
c=None,
num_layers=2,
bias=True,
act='relu',
grad_clip=None,
optimizer='RiemannianAdam',
weight_decay=0.01,
lr=0.001,
gamma=0.5,
lr_reduce_freq=500,
cuda=0,
epochs=50,
min_epochs=50,
patience=None,
seed=42,
log_freq=1,
eval_freq=1,
val_prop=0.0002,
test_prop=0.3,
double_precision=0,
dropout=0.1,
lambda_rec=1.0,
lambda_lp=1.0,
num_dec_layers=2,
use_att= True,
att_type= 'sparse_adjmask_dist',
att_logit='tanh',
beta = 0.2,
classifier=None,
clusterer = None,
normalize_adj=False,
normalize_feats=True,
anomaly_detector=None
):
        self.args = create_args(
            dim, hidden_dim, c, num_layers, bias, act, grad_clip, optimizer,
            weight_decay, lr, gamma, lr_reduce_freq, cuda, epochs, min_epochs,
            patience, seed, log_freq, eval_freq, val_prop, test_prop,
            double_precision, dropout, lambda_rec, lambda_lp, num_dec_layers,
            use_att, att_type, att_logit, beta, classifier, clusterer,
            normalize_adj, normalize_feats, anomaly_detector)
self.cls = None
self.args.n_nodes = adj.shape[0]
self.args.feat_dim = features.shape[1]
self.args.n_classes = len(np.unique(labels))
self.data = process_data(self.args,adj,features,labels)
        if self.args.c is None:
self.args.c_trainable = 1
self.args.c = 1.0
np.random.seed(self.args.seed)
torch.manual_seed(self.args.seed)
if int(self.args.double_precision):
torch.set_default_dtype(torch.float64)
if int(self.args.cuda) >= 0:
torch.cuda.manual_seed(self.args.seed)
self.args.device = 'cuda:' + str(self.args.cuda) if int(self.args.cuda) >= 0 else 'cpu'
self.args.patience = self.args.epochs if not self.args.patience else int(self.args.patience)
if not self.args.lr_reduce_freq:
self.args.lr_reduce_freq = self.args.epochs
self.args.nb_false_edges = len(self.data['train_edges_false'])
self.args.nb_edges = len(self.data['train_edges'])
st0 = np.random.get_state()
self.args.np_seed = st0
np.random.set_state(self.args.np_seed)
for x, val in self.data.items():
if 'adj' in x:
self.data[x] = sparse_mx_to_torch_sparse_tensor(self.data[x])
self.model = LPModel(self.args)
if self.args.cuda is not None and int(self.args.cuda) >= 0 :
os.environ['CUDA_VISIBLE_DEVICES'] = str(self.args.cuda)
self.model = self.model.to(self.args.device)
for x, val in self.data.items():
if torch.is_tensor(self.data[x]):
self.data[x] = self.data[x].to(self.args.device)
self.adj_train_enc = self.data['adj_train_enc']
self.optimizer = getattr(optimizers, self.args.optimizer)(params=self.model.parameters(), lr=self.args.lr,
weight_decay=self.args.weight_decay)
self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer,
step_size=int(self.args.lr_reduce_freq),
gamma=float(self.args.gamma)
)
self.best_emb = None
def fit(self):
logging.getLogger().setLevel(logging.INFO)
logging.info(f'Using: {self.args.device}')
logging.info(str(self.model))
tot_params = sum([np.prod(p.size()) for p in self.model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
t_total = time.time()
counter = 0
best_val_metrics = self.model.init_metric_dict()
best_losses = []
train_losses = []
val_losses = []
for epoch in range(self.args.epochs):
t = time.time()
self.model.train()
self.optimizer.zero_grad()
embeddings = self.model.encode(self.data['features'], self.adj_train_enc)
train_metrics = self.model.compute_metrics(embeddings, self.data, 'train', epoch)
print(train_metrics)
train_metrics['loss'].backward()
if self.args.grad_clip is not None:
max_norm = float(self.args.grad_clip)
all_params = list(self.model.parameters())
for param in all_params:
torch.nn.utils.clip_grad_norm_(param, max_norm)
self.optimizer.step()
self.lr_scheduler.step()
train_losses.append(train_metrics['loss'].item())
if(len(best_losses) == 0):
best_losses.append(train_losses[0])
elif (best_losses[-1] > train_losses[-1]):
best_losses.append(train_losses[-1])
else:
best_losses.append(best_losses[-1])
with torch.no_grad():
if (epoch + 1) % self.args.log_freq == 0:
logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
'lr: {}'.format(self.lr_scheduler.get_lr()[0]),
format_metrics(train_metrics, 'train'),
'time: {:.4f}s'.format(time.time() - t)
]))
if (epoch + 1) % self.args.eval_freq == 0:
self.model.eval()
embeddings = self.model.encode(self.data['features'], self.adj_train_enc)
#val_metrics = self.model.compute_metrics(embeddings, self.data, 'val')
# val_losses.append(val_metrics['loss'].item())
# if (epoch + 1) % self.args.log_freq == 0:
# logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
# if self.model.has_improved(best_val_metrics, val_metrics):
# self.best_emb = embeddings
# best_val_metrics = val_metrics
# counter = 0
# else:
# counter += 1
# if counter == self.args.patience and epoch > self.args.min_epochs:
# logging.info("Early stopping")
# break
logging.info("Training Finished!")
logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
# train_idx = np.unique(self.data['train_edges'][:,0].cpu().detach().numpy())
# val_idx = np.unique(self.data['val_edges'][:,0].cpu().detach().numpy())
# idx = np.unique(np.concatenate((train_idx,val_idx)))
# X = self.model.manifold.logmap0(self.best_emb[idx],self.model.encoder.curvatures[-1]).cpu().detach().numpy()
# y = self.data['labels'].reshape(-1,1)[idx]
# if(self.args.classifier):
# self.cls = get_classifier(self.args, X,y)
# acc,f1,recall,precision,roc_auc = calculate_metrics(self.cls,X,y)
# elif self.args.clusterer:
# y = y.reshape(-1,)
# acc,f1,recall,precision,roc_auc = get_clustering_algorithm(self.args.clusterer,X,y)[6:]
# elif self.args.anomaly_detector:
# y = y.reshape(-1,)
# acc,f1,recall,precision,roc_auc = get_anomaly_detection_algorithm(self.args.anomaly_detector,X,y)[6:]
# return {'train':train_losses,'best':best_losses,'val':val_losses},acc,f1,recall,precision,roc_auc , time.time() - t_total
return {'train':train_losses,'best':best_losses,'val':val_losses}, time.time() - t_total
def predict(self):
self.model.eval()
test_idx = np.unique(self.data['test_edges'][:,0].cpu().detach().numpy())
embeddings = self.model.encode(self.data['features'], self.adj_train_enc)
val_metrics = self.model.compute_metrics(embeddings, self.data, 'test')
data = self.model.manifold.logmap0(embeddings[test_idx],self.model.encoder.curvatures[-1]).cpu().detach().numpy()
labels = self.data['labels'].reshape(-1,1)[test_idx]
if self.args.classifier:
acc,f1,recall,precision,roc_auc = calculate_metrics(self.cls,data,labels)
elif self.args.clusterer:
labels = labels.reshape(-1,)
acc,f1,recall,precision,roc_auc = get_clustering_algorithm(self.args.clusterer,data,labels)[6:]
        elif self.args.anomaly_detector:
            labels = labels.reshape(-1,)
            acc,f1,recall,precision,roc_auc = get_anomaly_detection_algorithm(self.args.anomaly_detector,data,labels)[6:]
        else:
            raise ValueError('predict() requires a classifier, clusterer, or anomaly detector to be set')
self.tb_embeddings = embeddings
return val_metrics['loss'].item(),acc,f1,recall,precision,roc_auc
def save_embeddings(self,directory):
self.model.eval()
embeddings = self.model.encode(self.data['features'], self.adj_train_enc)
tb_embeddings_euc = self.model.manifold.logmap0(embeddings,self.model.encoder.curvatures[-1])
for_classification_hyp = np.hstack((embeddings.cpu().detach().numpy(),self.data['labels'].reshape(-1,1)))
for_classification_euc = np.hstack((tb_embeddings_euc.cpu().detach().numpy(),self.data['labels'].reshape(-1,1)))
hyp_file_path = os.path.join(directory,'hgcae_embeddings_hyp.csv')
euc_file_path = os.path.join(directory,'hgcae_embeddings_euc.csv')
np.savetxt(hyp_file_path, for_classification_hyp, delimiter=',')
np.savetxt(euc_file_path, for_classification_euc, delimiter=',')
"""Attention layers (some modules are copied from https://github.com/Diego999/pyGAT.)"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def HypAggAtt(in_features, manifold, dropout, act=None, att_type=None, att_logit=None, beta=0):
att_logit = get_att_logit(att_logit, att_type)
return GeometricAwareHypAggAtt(in_features, manifold, dropout, lambda x: x, att_logit=att_logit, beta=beta)
class GeometricAwareHypAggAtt(nn.Module):
def __init__(self, in_features, manifold, dropout, act, att_logit=torch.tanh, beta=0.):
super(GeometricAwareHypAggAtt, self).__init__()
self.dropout = dropout
self.att_logit=att_logit
self.special_spmm = SpecialSpmm()
self.m = manifold
self.beta = nn.Parameter(torch.Tensor([1e-6]))
self.con = nn.Parameter(torch.Tensor([1e-6]))
self.act = act
self.in_features = in_features
def forward (self, x, adj, c=1):
n = x.size(0)
edge = adj._indices()
assert not torch.isnan(self.beta).any()
edge_h = self.beta * self.m.sqdist(x[edge[0, :], :], x[edge[1, :], :], c) + self.con
self.edge_h = edge_h
assert not torch.isnan(edge_h).any()
edge_e = self.att_logit(edge_h)
self.edge_e = edge_e
ones = torch.ones(size=(n, 1))
if x.is_cuda:
ones = ones.to(x.device)
e_rowsum = self.special_spmm(edge, abs(edge_e), torch.Size([n, n]), ones) + 1e-10
return edge_e, e_rowsum
class SpecialSpmmFunction(torch.autograd.Function):
"""Special function for only sparse region backpropataion layer."""
# generate sparse matrix from `indicex, values, shape` and matmul with b
# Previously, `AXW` computing did not need bp to `A`.
# To trian attention of `A`, now bp through sparse matrix needed.
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad == False
a = torch.sparse_coo_tensor(indices, values, shape, device=b.device) # make sparse matrix shaped of `NxN`
ctx.save_for_backward(a, b) # save sparse matrix for bp
ctx.N = shape[0] # number of nodes
return torch.matmul(a, b)
@staticmethod
def backward(ctx, grad_output):
assert not torch.isnan(grad_output).any()
# grad_output : Nxd gradient
# a : NxN adj(attention) matrix, b: Nxd node feature
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] # flattening (x,y) --> nx + y
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b
class SpecialSpmm(nn.Module):
def forward(self, indices, values, shape, b):
return SpecialSpmmFunction.apply(indices, values, shape, b)
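# Minimal sketch of SpecialSpmm on a hypothetical two-node graph (illustrative):
# edge = torch.tensor([[0, 1], [1, 0]])        # COO indices of edges 0->1, 1->0
# edge_e = torch.ones(2, requires_grad=True)   # per-edge attention values
# feats = torch.randn(2, 4)                    # node features
# out = SpecialSpmm()(edge, edge_e, torch.Size([2, 2]), feats)
# out.sum().backward()                         # gradient flows back to edge_e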
def get_att_logit(att_logit, att_type):
if att_logit:
att_logit = getattr(torch, att_logit)
return att_logit
"""
Hyperbolic layers.
Most of the hyperbolic-layer code is adapted from HGCN.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from Ghypeddings.HGCAE.layers.att_layers import HypAggAtt, SpecialSpmm
def get_dim_act_curv(args):
"""
Helper function to get dimension and activation at every layer.
:param args:
:return:
"""
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim]
    # Check that num_layers and hidden_dim match
if args.num_layers > 1:
hidden_dim = [args.hidden_dim for _ in range(args.num_layers -1)]
if args.num_layers != len(hidden_dim) + 1:
raise RuntimeError('Check dimension hidden:{}, num_layers:{}'.format(args.hidden_dim, args.num_layers) )
dims = dims + hidden_dim
dims += [args.dim]
acts += [act]
n_curvatures = args.num_layers
if args.c_trainable == 1: # NOTE : changed from # if args.c is None:
# create list of trainable curvature parameters
curvatures = [nn.Parameter(torch.Tensor([args.c]).to(args.device)) for _ in range(n_curvatures)]
else:
# fixed curvature
curvatures = [torch.tensor([args.c]) for _ in range(n_curvatures)]
if not args.cuda == -1:
curvatures = [curv.to(args.device) for curv in curvatures]
return dims, acts, curvatures
class HNNLayer(nn.Module):
"""
Hyperbolic neural networks layer.
"""
def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias):
super(HNNLayer, self).__init__()
self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
self.hyp_act = HypAct(manifold, c_in, c_out, act)
def forward(self, x):
h = self.linear.forward(x)
h = self.hyp_act.forward(h)
return h
class HyperbolicGraphConvolution(nn.Module):
"""
Hyperbolic graph convolution layer.
"""
def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att,
att_type='sparse_adjmask_dist', att_logit=torch.exp, beta=0., decode=False):
super(HyperbolicGraphConvolution, self).__init__()
self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
self.agg = HypAgg(manifold, c_in, use_att, out_features, dropout, att_type=att_type, att_logit=att_logit, beta=beta, decode=decode)
self.hyp_act = HypAct(manifold, c_in, c_out, act)
self.decode = decode
def forward(self, input):
x, adj = input
assert not torch.isnan(self.hyp_act.c_in).any()
self.hyp_act.c_in.data = torch.clamp_min(self.hyp_act.c_in,1e-12)
if self.hyp_act.c_out:
assert not torch.isnan(self.hyp_act.c_out).any()
self.hyp_act.c_out.data = torch.clamp_min(self.hyp_act.c_out,1e-12)
assert not torch.isnan(x).any()
h = self.linear.forward(x)
assert not torch.isnan(h).any()
h = self.agg.forward(h, adj, prev_x=x)
assert not torch.isnan(h).any()
h = self.hyp_act.forward(h)
assert not torch.isnan(h).any()
output = h, adj
return output
class HypLinear(nn.Module):
"""
Hyperbolic linear layer.
"""
def __init__(self, manifold, in_features, out_features, c, dropout, use_bias):
super(HypLinear, self).__init__()
self.manifold = manifold
self.in_features = in_features
self.out_features = out_features
self.c = c
self.dropout = dropout
self.use_bias = use_bias
# self.bias = nn.Parameter(torch.Tensor(out_features))
self.bias = nn.Parameter(torch.Tensor(1, out_features))
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
self.reset_parameters()
def reset_parameters(self):
init.xavier_uniform_(self.weight, gain=math.sqrt(2))
init.constant_(self.bias, 0)
def forward(self, x):
drop_weight = F.dropout(self.weight, self.dropout, training=self.training)
mv = self.manifold.mobius_matvec(drop_weight, x, self.c)
res = self.manifold.proj(mv, self.c)
if self.use_bias:
bias = self.bias
hyp_bias = self.manifold.expmap0(bias, self.c)
hyp_bias = self.manifold.proj(hyp_bias, self.c)
res = self.manifold.mobius_add(res, hyp_bias, c=self.c)
res = self.manifold.proj(res, self.c)
return res
def extra_repr(self):
return 'in_features={}, out_features={}, c={}'.format(
self.in_features, self.out_features, self.c
)
class HypAgg(Module):
"""
Hyperbolic aggregation layer.
"""
def __init__(self, manifold, c, use_att, in_features, dropout, att_type='sparse_adjmask_dist', att_logit=None, beta=0, decode=False):
super(HypAgg, self).__init__()
self.manifold = manifold
self.c = c
self.use_att = use_att
self.in_features = in_features
self.dropout = dropout
if use_att:
self.att = HypAggAtt(in_features, manifold, dropout, act=None, att_type=att_type, att_logit=att_logit, beta=beta)
self.att_type = att_type
self.special_spmm = SpecialSpmm()
self.decode = decode
def forward(self, x, adj, prev_x=None):
if self.use_att:
dist = 'dist' in self.att_type
if dist:
if 'sparse' in self.att_type:
if self.decode:
# NOTE : AGG(prev_x)
edge_e, e_rowsum = self.att(prev_x, adj, self.c) # SparseAtt
else:
# NOTE : AGG(x)
edge_e, e_rowsum = self.att(x, adj, self.c) # SparseAtt
self.edge_e = edge_e
self.e_rowsum = e_rowsum
## SparseAtt
x_tangent = self.manifold.logmap0(x, c=self.c)
N = x.size()[0]
edge = adj._indices()
support_t = self.special_spmm(edge, edge_e, torch.Size([N, N]), x_tangent)
assert not torch.isnan(support_t).any()
support_t = support_t.div(e_rowsum)
assert not torch.isnan(support_t).any()
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
else:
adj = self.att(x, adj, self.c) # DenseAtt
x_tangent = self.manifold.logmap0(x, c=self.c)
support_t = torch.spmm(adj, x_tangent)
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
else:
## MLP attention
x_tangent = self.manifold.logmap0(x, c=self.c)
adj = self.att(x_tangent, adj)
support_t = torch.spmm(adj, x_tangent)
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
else:
x_tangent = self.manifold.logmap0(x, c=self.c)
support_t = torch.spmm(adj, x_tangent)
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
return output
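    # Aggregation is performed in the tangent space at the origin: features are
    # mapped there with logmap0, combined through the (optionally
    # attention-weighted) sparse adjacency, then mapped back with expmap0 and
    # projected onto the ball.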
def extra_repr(self):
return 'c={}, use_att={}, decode={}'.format(
self.c, self.use_att, self.decode
)
class HypAct(Module):
"""
Hyperbolic activation layer.
"""
def __init__(self, manifold, c_in, c_out, act):
super(HypAct, self).__init__()
self.manifold = manifold
self.c_in = c_in
self.c_out = c_out
self.act = act
def forward(self, x):
if self.manifold.name == 'PoincareBall':
if self.c_out:
xt = self.manifold.activation(x, self.act, self.c_in, self.c_out)
return xt
else:
xt = self.manifold.logmap0(x, c=self.c_in)
return xt
        else:
            raise NotImplementedError('HypAct is only implemented for the PoincareBall manifold')
def extra_repr(self):
return 'Manifold={},\n c_in={},\n act={},\n c_out={}'.format(
self.manifold.name, self.c_in, self.act.__name__, self.c_out
)
"""Euclidean layers."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
def get_dim_act(args):
"""
Helper function to get dimension and activation at every layer.
:param args:
:return:
"""
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim]
if args.num_layers > 1:
        # Check that num_layers and hidden_dim match
        hidden_dim = [int(h) for h in args.hidden_dim.split(',')]
        if args.num_layers != len(hidden_dim) + 1:
            raise RuntimeError('Check dimension hidden:{}, num_layers:{}'.format(args.hidden_dim, args.num_layers))
dims = dims + hidden_dim
dims += [args.dim]
acts += [act]
return dims, acts
class Linear(Module):
"""
Simple Linear layer with dropout.
"""
def __init__(self, in_features, out_features, dropout, act, use_bias):
super(Linear, self).__init__()
self.dropout = dropout
self.linear = nn.Linear(in_features, out_features, use_bias)
self.act = act
def forward(self, x):
hidden = self.linear.forward(x)
hidden = F.dropout(hidden, self.dropout, training=self.training)
out = self.act(hidden)
return out
'''
InnerProductDecoder implementation from:
https://github.com/zfjsail/gae-pytorch/blob/master/gae/model.py
'''
class InnerProductDecoder(nn.Module):
"""Decoder for using inner product for prediction."""
def __init__(self, dropout=0, act=torch.sigmoid):
super(InnerProductDecoder, self).__init__()
self.dropout = dropout
self.act = act
def forward(self, emb_in, emb_out):
cos_dist = emb_in * emb_out
probs = self.act(cos_dist.sum(1))
return probs
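# Illustrative check (a sketch): for two identical unit embeddings the decoder
# returns sigmoid(<e, e>) = sigmoid(1):
# >>> dec = InnerProductDecoder()
# >>> e = torch.tensor([[1.0, 0.0]])
# >>> dec(e, e)
# tensor([0.7311])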
'''
Most of the hyperbolic-layer code is adapted from HGCN.
Refer to the Lorentz implementation in HGCN if you need it.
'''
from Ghypeddings.HGCAE.manifolds.base import ManifoldParameter
from Ghypeddings.HGCAE.manifolds.euclidean import Euclidean
from Ghypeddings.HGCAE.manifolds.poincare import PoincareBall
'''
Most of the hyperbolic-layer code is adapted from HGCN.
'''
from torch.nn import Parameter
class Manifold(object):
"""
Abstract class to define operations on a manifold.
"""
def __init__(self):
super().__init__()
self.eps = 10e-8
def sqdist(self, p1, p2, c):
"""Squared distance between pairs of points."""
raise NotImplementedError
def egrad2rgrad(self, p, dp, c):
"""Converts Euclidean Gradient to Riemannian Gradients."""
raise NotImplementedError
def proj(self, p, c):
"""Projects point p on the manifold."""
raise NotImplementedError
def proj_tan(self, u, p, c):
"""Projects u on the tangent space of p."""
raise NotImplementedError
def proj_tan0(self, u, c):
"""Projects u on the tangent space of the origin."""
raise NotImplementedError
def expmap(self, u, p, c):
"""Exponential map of u at point p."""
raise NotImplementedError
def logmap(self, p1, p2, c):
"""Logarithmic map of point p1 at point p2."""
raise NotImplementedError
def expmap0(self, u, c):
"""Exponential map of u at the origin."""
raise NotImplementedError
def logmap0(self, p, c):
"""Logarithmic map of point p at the origin."""
raise NotImplementedError
def mobius_add(self, x, y, c, dim=-1):
"""Adds points x and y."""
raise NotImplementedError
def mobius_matvec(self, m, x, c):
"""Performs hyperboic martrix-vector multiplication."""
raise NotImplementedError
def init_weights(self, w, c, irange=1e-5):
"""Initializes random weigths on the manifold."""
raise NotImplementedError
def inner(self, p, c, u, v=None):
"""Inner product for tangent vectors at point x."""
raise NotImplementedError
def ptransp(self, x, y, u, c):
"""Parallel transport of u from x to y."""
raise NotImplementedError
class ManifoldParameter(Parameter):
"""
Subclass of torch.nn.Parameter for Riemannian optimization.
"""
def __new__(cls, data, requires_grad, manifold, c):
return Parameter.__new__(cls, data, requires_grad)
def __init__(self, data, requires_grad, manifold, c):
self.c = c
self.manifold = manifold
def __repr__(self):
return '{} Parameter containing:\n'.format(self.manifold.name) + super(Parameter, self).__repr__()
'''
Most of the hyperbolic-layer code is adapted from HGCN.
'''
import torch
from Ghypeddings.HGCAE.manifolds.base import Manifold
class Euclidean(Manifold):
"""
Euclidean Manifold class.
"""
def __init__(self):
super(Euclidean, self).__init__()
self.name = 'Euclidean'
def normalize(self, p):
dim = p.size(-1)
p_norm = torch.renorm(p, 2, 0, 1.)
return p_norm
def sqdist(self, p1, p2, c):
return (p1 - p2).pow(2).sum(dim=-1)
def egrad2rgrad(self, p, dp, c):
return dp
def proj(self, p, c):
return p
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
return p + u
def logmap(self, p1, p2, c):
return p2 - p1
def expmap0(self, u, c):
return u
def logmap0(self, p, c):
return p
def mobius_add(self, x, y, c, dim=-1):
return x + y
def mobius_matvec(self, m, x, c):
mx = x @ m.transpose(-1, -2)
return mx
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def inner(self, p, c, u, v=None, keepdim=False):
if v is None:
v = u
return (u * v).sum(dim=-1, keepdim=keepdim)
def ptransp(self, x, y, v, c):
return v
'''
Most of the hyperbolic-layer code is adapted from HGCN.
'''
import torch
from Ghypeddings.HGCAE.manifolds.base import Manifold
from torch.autograd import Function
from Ghypeddings.HGCAE.utils.math_utils import artanh, tanh
class PoincareBall(Manifold):
"""
    PoincareBall Manifold class.
We use the following convention: x0^2 + x1^2 + ... + xd^2 < 1 / c
Note that 1/sqrt(c) is the Poincare ball radius.
"""
def __init__(self, ):
super(PoincareBall, self).__init__()
self.name = 'PoincareBall'
self.min_norm = 1e-15
self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}
def sqdist(self, p1, p2, c):
sqrt_c = c ** 0.5
dist_c = artanh(
sqrt_c * self.mobius_add(-p1, p2, c, dim=-1).norm(dim=-1, p=2, keepdim=False)
)
dist = dist_c * 2 / sqrt_c
return dist ** 2
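    # The code above implements the Poincare distance
    #   d_c(p1, p2) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-p1) (+)_c p2||)
    # and returns its square; (+)_c denotes the Mobius addition defined below.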
def _lambda_x(self, x, c):
x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
return 2 / (1. - c * x_sqnorm).clamp_min(self.min_norm)
def egrad2rgrad(self, p, dp, c):
lambda_p = self._lambda_x(p, c)
dp /= lambda_p.pow(2)
return dp
def proj(self, x, c):
norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
cond = norm > maxnorm
projected = x / norm * maxnorm
return torch.where(cond, projected, x)
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
sqrt_c = c ** 0.5
u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
second_term = (
tanh(sqrt_c / 2 * self._lambda_x(p, c) * u_norm)
* u
/ (sqrt_c * u_norm)
)
gamma_1 = self.mobius_add(p, second_term, c)
return gamma_1
def logmap(self, p1, p2, c):
sub = self.mobius_add(-p1, p2, c)
sub_norm = sub.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
lam = self._lambda_x(p1, c)
sqrt_c = c ** 0.5
return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm
def expmap0(self, u, c):
sqrt_c = c ** 0.5
u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm)
gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
return gamma_1
def logmap0(self, p, c):
sqrt_c = c ** 0.5
p_norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
scale = 1. / sqrt_c * artanh(sqrt_c * p_norm) / p_norm
return scale * p
def mobius_add(self, x, y, c, dim=-1):
x2 = x.pow(2).sum(dim=dim, keepdim=True)
y2 = y.pow(2).sum(dim=dim, keepdim=True)
xy = (x * y).sum(dim=dim, keepdim=True)
num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
return num / denom.clamp_min(self.min_norm)
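    # mobius_add computes, with <.,.> the Euclidean inner product:
    #   x (+)_c y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y)
    #               / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2)
    # which reduces to ordinary vector addition as c -> 0.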
def mobius_matvec(self, m, x, c):
sqrt_c = c ** 0.5
x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
mx = x @ m.transpose(-1, -2)
mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
        cond = (mx == 0).all(dim=-1, keepdim=True)
res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
res = torch.where(cond, res_0, res_c)
return res
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def _gyration(self, u, v, w, c, dim: int = -1):
u2 = u.pow(2).sum(dim=dim, keepdim=True)
v2 = v.pow(2).sum(dim=dim, keepdim=True)
uv = (u * v).sum(dim=dim, keepdim=True)
uw = (u * w).sum(dim=dim, keepdim=True)
vw = (v * w).sum(dim=dim, keepdim=True)
c2 = c ** 2
a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw
b = -c2 * vw * u2 - c * uw
d = 1 + 2 * c * uv + c2 * u2 * v2
return w + 2 * (a * u + b * v) / d.clamp_min(self.min_norm)
def inner(self, x, c, u, v=None, keepdim=False, dim=-1):
if v is None:
v = u
lambda_x = self._lambda_x(x, c)
return lambda_x ** 2 * (u * v).sum(dim=dim, keepdim=keepdim)
def ptransp(self, x, y, u, c):
lambda_x = self._lambda_x(x, c)
lambda_y = self._lambda_x(y, c)
return self._gyration(y, -x, u, c) * lambda_x / lambda_y
def activation(self, x, act, c_in, c_out):
x_act = act(x)
x_prev = self.logmap0(x_act, c_in)
x_next = self.expmap0(x_prev, c_out)
return x_next
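# Sanity-check sketch (illustrative, not part of the original file): expmap0
# and logmap0 are mutually inverse for points inside the ball.
# m = PoincareBall()
# u = torch.randn(5, 3) * 0.1
# torch.allclose(m.logmap0(m.expmap0(u, c=1.0), c=1.0), u, atol=1e-5)  # True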
import Ghypeddings.HGCAE.models.encoders as encoders
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from Ghypeddings.HGCAE.models.decoders import model2decoder
from Ghypeddings.HGCAE.layers.layers import InnerProductDecoder
from sklearn.metrics import roc_auc_score, average_precision_score
from Ghypeddings.HGCAE.utils.eval_utils import acc_f1
from sklearn import cluster
from sklearn.metrics import accuracy_score, normalized_mutual_info_score, adjusted_rand_score
import Ghypeddings.HGCAE.manifolds as manifolds
import Ghypeddings.HGCAE.models.encoders as encoders
class BaseModel(nn.Module):
"""
Base model for graph embedding tasks.
"""
def __init__(self, args):
super(BaseModel, self).__init__()
self.manifold_name = "PoincareBall"
if args.c is not None:
self.c = torch.tensor([args.c])
if not args.cuda == -1:
self.c = self.c.to(args.device)
else:
self.c = nn.Parameter(torch.Tensor([1.]))
self.manifold = getattr(manifolds, self.manifold_name)()
self.nnodes = args.n_nodes
self.n_classes = args.n_classes
self.encoder = getattr(encoders, "HGCAE")(self.c, args)
self.num_layers=args.num_layers
# Embedding c
        self.hyperbolic_embedding = bool(args.use_att)
self.decoder_type = 'InnerProductDecoder'
self.dc = InnerProductDecoder(dropout=0, act=torch.sigmoid)
def encode(self, x, adj):
h = self.encoder.encode(x, adj)
return h
def pred_link_score(self, h, idx): # for LP,REC
emb_in = h[idx[:, 0], :]
emb_out = h[idx[:, 1], :]
probs = self.dc.forward(emb_in, emb_out)
return probs
def decode(self, h, adj, idx): # REC
output = self.decoder.decode(h, adj)
return output
def eval_cluster(self, embeddings, data, split):
if self.hyperbolic_embedding:
emb_c = self.encoder.layers[-1].hyp_act.c_out
embeddings = self.manifold.logmap0(embeddings.to(emb_c.device), c=emb_c).cpu()
idx = data[f'idx_{split}']
n_classes = self.n_classes
embeddings_to_cluster = embeddings[idx].detach().cpu().numpy()
        # gt_label = data['labels'][idx].cpu().numpy()
        gt_label = data['labels']  # NOTE: assumes idx covers all nodes; otherwise gt_label and pred_label lengths diverge
        kmeans = cluster.KMeans(n_clusters=n_classes, algorithm='auto')  # NOTE: 'auto' was removed in newer scikit-learn releases
        pred_label = kmeans.fit_predict(embeddings_to_cluster)  # fit_predict alone suffices; the extra fit() was redundant
from munkres import Munkres
def best_map(L1,L2):
#L1 should be the groundtruth labels and L2 should be the clustering labels we got
Label1 = np.unique(L1)
nClass1 = len(Label1)
Label2 = np.unique(L2)
nClass2 = len(Label2)
nClass = np.maximum(nClass1,nClass2)
G = np.zeros((nClass,nClass))
for i in range(nClass1):
ind_cla1 = L1 == Label1[i]
ind_cla1 = ind_cla1.astype(float)
for j in range(nClass2):
ind_cla2 = L2 == Label2[j]
ind_cla2 = ind_cla2.astype(float)
G[i,j] = np.sum(ind_cla2 * ind_cla1)
m = Munkres()
index = m.compute(-G.T)
index = np.array(index)
c = index[:,1]
newL2 = np.zeros(L2.shape)
for i in range(nClass2):
newL2[L2 == Label2[i]] = Label1[c[i]]
return newL2
def err_rate(gt_s, s):
c_x = best_map(gt_s, s)
err_x = np.sum(gt_s[:] !=c_x[:])
missrate = err_x.astype(float) / (gt_s.shape[0])
return missrate
acc = 1-err_rate(gt_label, pred_label)
# acc = accuracy_score(gt_label, pred_label)
nmi = normalized_mutual_info_score(gt_label, pred_label, average_method='arithmetic')
ari = adjusted_rand_score(gt_label, pred_label)
metrics = { 'cluster_acc': acc, 'nmi': nmi, 'ari': ari}
return metrics, pred_label
def compute_metrics(self, embeddings, data, split, epoch=None):
raise NotImplementedError
def init_metric_dict(self):
raise NotImplementedError
def has_improved(self, m1, m2):
raise NotImplementedError
class LPModel(BaseModel):
"""
Base model for link prediction task.
"""
def __init__(self, args):
super(LPModel, self).__init__(args)
self.nb_false_edges = args.nb_false_edges
        self.positive_edge_sampling = True
        if self.positive_edge_sampling:
            self.nb_edges = min(args.nb_edges, 5000) # NOTE : beware of overly dense edge sets
else:
self.nb_edges = args.nb_edges
if args.lambda_rec > 0:
self.num_dec_layers = args.num_dec_layers
self.lambda_rec = args.lambda_rec
c = self.encoder.curvatures if hasattr(self.encoder, 'curvatures') else args.c ### handle HNN
self.decoder = model2decoder(c, args, 'rec')
else:
self.lambda_rec = 0
if args.lambda_lp > 0:
self.lambda_lp = args.lambda_lp
else:
self.lambda_lp = 0
def compute_metrics(self, embeddings, data, split, epoch=None):
if split == 'train':
num_true_edges = data[f'{split}_edges'].shape[0]
            if self.positive_edge_sampling and num_true_edges > self.nb_edges:
edges_true = data[f'{split}_edges'][np.random.randint(0, num_true_edges, self.nb_edges)]
else:
edges_true = data[f'{split}_edges']
edges_false = data[f'{split}_edges_false'][np.random.randint(0, self.nb_false_edges, self.nb_edges)]
else:
edges_true = data[f'{split}_edges']
edges_false = data[f'{split}_edges_false']
pos_scores = self.pred_link_score(embeddings, edges_true)
neg_scores = self.pred_link_score(embeddings, edges_false)
assert not torch.isnan(pos_scores).any()
assert not torch.isnan(neg_scores).any()
loss = F.binary_cross_entropy(pos_scores, torch.ones_like(pos_scores))
loss += F.binary_cross_entropy(neg_scores, torch.zeros_like(neg_scores))
if pos_scores.is_cuda:
pos_scores = pos_scores.cpu()
neg_scores = neg_scores.cpu()
labels = [1] * pos_scores.shape[0] + [0] * neg_scores.shape[0]
preds = list(pos_scores.data.numpy()) + list(neg_scores.data.numpy())
roc = roc_auc_score(labels, preds)
ap = average_precision_score(labels, preds)
metrics = {'loss': loss, 'roc': roc, 'ap': ap}
assert not torch.isnan(loss).any()
if self.lambda_rec:
idx = data['idx_all']
recon = self.decode(embeddings, data['adj_train_dec'], idx) ## NOTE : adj
assert not torch.isnan(recon).any()
if self.num_dec_layers == self.num_layers:
target = data['features'][idx]
elif self.num_dec_layers == self.num_layers - 1:
target = self.encoder.features[0].detach()[idx]
else:
                raise RuntimeError('num_dec_layers must equal num_layers or num_layers - 1')
loss_rec = self.lambda_rec * torch.nn.functional.mse_loss(recon[idx], target , reduction='mean')
assert not torch.isnan(loss_rec).any()
loss_lp = loss * self.lambda_lp
metrics.update({'loss': loss_lp + loss_rec, 'loss_rec': loss_rec, 'loss_lp': loss_lp})
return metrics
def init_metric_dict(self):
return {'roc': -1, 'ap': -1}
def has_improved(self, m1, m2):
return 0.5 * (m1['roc'] + m1['ap']) < 0.5 * (m2['roc'] + m2['ap'])
"""Graph decoders."""
import Ghypeddings.HGCAE.manifolds as manifolds
import torch.nn as nn
import torch.nn.functional as F
import torch
class Decoder(nn.Module):
"""
Decoder abstract class
"""
def __init__(self, c):
super(Decoder, self).__init__()
self.c = c
def classify(self, x, adj):
'''
output
- nc : probs
- rec : input_feat
'''
if self.decode_adj:
input = (x, adj)
output, _ = self.classifier.forward(input)
else:
output = self.classifier.forward(x)
return output
def decode(self, x, adj):
'''
output
- nc : probs
- rec : input_feat
'''
if self.decode_adj:
input = (x, adj)
output, _ = self.decoder.forward(input)
else:
output = self.decoder.forward(x)
return output
import Ghypeddings.HGCAE.layers.hyp_layers as hyp_layers
class HGCAEDecoder(Decoder):
"""
Decoder for HGCAE
"""
def __init__(self, c, args, task):
super(HGCAEDecoder, self).__init__(c)
self.manifold = getattr(manifolds, 'PoincareBall')()
assert args.num_layers > 0
dims, acts, _ = hyp_layers.get_dim_act_curv(args)
dims = dims[::-1]
acts = acts[::-1][:-1] + [lambda x: x] # Last layer without act
self.curvatures = self.c[::-1]
encdec_share_curvature = False
if not encdec_share_curvature and args.num_layers == args.num_dec_layers: # do not share and enc-dec mirror-shape
num_c = len(self.curvatures)
self.curvatures = self.curvatures[:1]
if args.c_trainable == 1:
self.curvatures += [nn.Parameter(torch.Tensor([args.c]).to(args.device))] * (num_c - 1)
else:
self.curvatures += [torch.tensor([args.c])] * (num_c - 1)
if not args.cuda == -1:
self.curvatures = [curv.to(args.device) for curv in self.curvatures]
self.curvatures = self.curvatures[:-1] + [None]
hgc_layers = []
num_dec_layers = args.num_dec_layers
for i in range(num_dec_layers):
c_in, c_out = self.curvatures[i], self.curvatures[i + 1]
in_dim, out_dim = dims[i], dims[i + 1]
act = acts[i]
hgc_layers.append(
hyp_layers.HyperbolicGraphConvolution(
self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att,
att_type=args.att_type, att_logit=args.att_logit, beta=args.beta, decode=True
)
)
self.decoder = nn.Sequential(*hgc_layers)
self.decode_adj = True
# NOTE : self.c is fixed, not trainable
def classify(self, x, adj):
h = self.manifold.logmap0(x, c=self.c)
return super(HGCAEDecoder, self).classify(h, adj)
def decode(self, x, adj):
output = super(HGCAEDecoder, self).decode(x, adj)
return output
model2decoder = HGCAEDecoder
"""Graph encoders."""
import Ghypeddings.HGCAE.manifolds as manifolds
import Ghypeddings.HGCAE.layers.hyp_layers as hyp_layers
import torch
import torch.nn as nn
class Encoder(nn.Module):
"""
Encoder abstract class.
"""
def __init__(self, c, use_cnn=None):
super(Encoder, self).__init__()
self.c = c
def encode(self, x, adj):
self.features = []
if self.encode_graph:
input = (x, adj)
xx = input
for i in range(len(self.layers)):
out = self.layers[i].forward(xx)
self.features.append(out[0])
xx = out
output , _ = xx
else:
output = self.layers.forward(x)
return output
class HGCAE(Encoder):
"""
Hyperbolic Graph Convolutional Auto-Encoders.
"""
def __init__(self, c, args): #, use_cnn
super(HGCAE, self).__init__(c, use_cnn=True)
self.manifold = getattr(manifolds, "PoincareBall")()
assert args.num_layers > 0
dims, acts, self.curvatures = hyp_layers.get_dim_act_curv(args)
if args.c_trainable == 1:
self.curvatures.append(nn.Parameter(torch.Tensor([args.c]).to(args.device)))
else:
self.curvatures.append(torch.tensor([args.c]).to(args.device))
hgc_layers = []
for i in range(len(dims) - 1):
c_in, c_out = self.curvatures[i], self.curvatures[i + 1]
in_dim, out_dim = dims[i], dims[i + 1]
act = acts[i]
hgc_layers.append(
hyp_layers.HyperbolicGraphConvolution(
self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att,
att_type=args.att_type, att_logit=args.att_logit, beta=args.beta
)
)
self.layers = nn.Sequential(*hgc_layers)
self.encode_graph = True
def encode(self, x, adj):
self.curvatures[0] = torch.clamp_min(self.curvatures[0],min=1e-12)
x_hyp = self.manifold.proj(
self.manifold.expmap0(self.manifold.proj_tan0(x, self.curvatures[0]), c=self.curvatures[0]),
c=self.curvatures[0])
return super(HGCAE, self).encode(x_hyp, adj)
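# NOTE: encode() first lifts the Euclidean inputs onto the Poincare ball
# (proj_tan0 -> expmap0 -> proj) using the first layer's curvature, then runs
# the hyperbolic convolution stack; per-layer outputs are cached in
# self.features, which the decoder uses as reconstruction targets.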
from torch.optim import Adam
from Ghypeddings.HGCAE.optimizers.radam import RiemannianAdam
"""Riemannian adam optimizer geoopt implementation (https://github.com/geoopt/)."""
import torch.optim
from Ghypeddings.HGCAE.manifolds import Euclidean,ManifoldParameter
_default_manifold = Euclidean()
class OptimMixin(object):
def __init__(self, *args, stabilize=None, **kwargs):
self._stabilize = stabilize
super().__init__(*args, **kwargs)
def stabilize_group(self, group):
pass
def stabilize(self):
"""Stabilize parameters if they are off-manifold due to numerical reasons
"""
for group in self.param_groups:
self.stabilize_group(group)
def copy_or_set_(dest, source):
"""
A workaround to respect strides of :code:`dest` when copying :code:`source`
(https://github.com/geoopt/geoopt/issues/70)
Parameters
----------
dest : torch.Tensor
Destination tensor where to store new data
source : torch.Tensor
Source data to put in the new tensor
Returns
-------
dest
torch.Tensor, modified inplace
"""
if dest.stride() != source.stride():
return dest.copy_(source)
else:
return dest.set_(source)
class RiemannianAdam(OptimMixin, torch.optim.Adam):
r"""Riemannian Adam with the same API as :class:`torch.optim.Adam`
Parameters
----------
params : iterable
iterable of parameters to optimize or dicts defining
parameter groups
lr : float (optional)
learning rate (default: 1e-3)
betas : Tuple[float, float] (optional)
coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps : float (optional)
term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay : float (optional)
weight decay (L2 penalty) (default: 0)
amsgrad : bool (optional)
whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
Other Parameters
----------------
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def step(self, closure=None):
"""Performs a single optimization step.
Arguments
---------
closure : callable (optional)
A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
with torch.no_grad():
for group in self.param_groups:
if "step" not in group:
group["step"] = 0
betas = group["betas"]
weight_decay = group["weight_decay"]
eps = group["eps"]
learning_rate = group["lr"]
amsgrad = group["amsgrad"]
for point in group["params"]:
grad = point.grad
if grad is None:
continue
if isinstance(point, (ManifoldParameter)):
manifold = point.manifold
c = point.c
else:
manifold = _default_manifold
c = None
if grad.is_sparse:
raise RuntimeError(
"Riemannian Adam does not support sparse gradients yet (PR is welcome)"
)
state = self.state[point]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(point)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(point)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(point)
# make local variables for easy access
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
# actual step
                    grad.add_(point, alpha=weight_decay)
                    grad = manifold.egrad2rgrad(point, grad, c)
                    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                    exp_avg_sq.mul_(betas[1]).add_(
                        manifold.inner(point, c, grad, keepdim=True), alpha=1 - betas[1]
                    )
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(eps)
else:
denom = exp_avg_sq.sqrt().add_(eps)
group["step"] += 1
bias_correction1 = 1 - betas[0] ** group["step"]
bias_correction2 = 1 - betas[1] ** group["step"]
step_size = (
learning_rate * bias_correction2 ** 0.5 / bias_correction1
)
# copy the state, we need it for retraction
# get the direction for ascend
direction = exp_avg / denom
# transport the exponential averaging to the new point
new_point = manifold.proj(manifold.expmap(-step_size * direction, point, c), c)
exp_avg_new = manifold.ptransp(point, new_point, exp_avg, c)
# use copy only for user facing point
copy_or_set_(point, new_point)
exp_avg.set_(exp_avg_new)
group["step"] += 1
if self._stabilize is not None and group["step"] % self._stabilize == 0:
self.stabilize_group(group)
return loss
@torch.no_grad()
def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, ManifoldParameter):
continue
state = self.state[p]
if not state: # due to None grads
continue
manifold = p.manifold
c = p.c
exp_avg = state["exp_avg"]
copy_or_set_(p, manifold.proj(p, c))
            exp_avg.set_(manifold.proj_tan(exp_avg, p, c))
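# Illustrative usage (a sketch): ManifoldParameter instances are updated with
# the manifold's expmap and parallel transport; all other parameters fall back
# to the Euclidean manifold.
# optimizer = RiemannianAdam(model.parameters(), lr=1e-3, stabilize=100)
# loss = compute_loss(model, batch)   # hypothetical loss computation
# loss.backward()
# optimizer.step()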