Commit bd058dcf by yacinetouahria: add code (parent b5c2a6ea; merge request !2 into Master)

# ===== .gitignore =====
__pycache__/
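# ===== layers.py =====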
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn as nn
import torch.nn.functional as F
class GraphConvolution(Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""
def __init__(self, in_features, out_features, bias=True):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = torch.spmm(adj, support)
if self.bias is not None:
return output + self.bias
else:
return output
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
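
# Illustrative usage sketch (not part of the original commit): a single
# forward pass through GraphConvolution on synthetic data. The sparse
# identity adjacency stands in for a normalized adjacency matrix.
def _demo_graph_convolution():
    torch.manual_seed(0)
    num_nodes, in_features, out_features = 4, 3, 2
    layer = GraphConvolution(in_features, out_features)
    x = torch.randn(num_nodes, in_features)
    adj = torch.eye(num_nodes).to_sparse()  # torch.spmm expects a sparse lhs
    out = layer(x, adj)
    print(out.shape)  # torch.Size([4, 2])

# ===== experiment script (original file name not shown) =====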
import os
import pickle
from runner import RUNNER
import numpy as np
import pandas as pd
def load_samples(dataset):
    features_path = os.path.join(os.getcwd(),'datasets',dataset,'features.pkl')
    adj_path = os.path.join(os.getcwd(),'datasets',dataset,'adjacency.pkl')
    labels_path = os.path.join(os.getcwd(),'datasets',dataset,'labels.pkl')
with open(adj_path,'rb') as f:
adj = pickle.load(f)
with open(features_path,'rb') as f:
features = pickle.load(f)
with open(labels_path,'rb') as f:
labels= pickle.load(f)
return adj,features,labels
if __name__ == "__main__":
dims = [2,5,10,20]
outlier_percentage = 0.1
lr = 0.001
invert = True
dataset = 'unsw_nb'
df = pd.DataFrame(None)
for dim in dims:
adj,features,labels = load_samples(dataset=dataset)
runner = RUNNER(adj=adj,features=features,labels=labels,lr=lr,alpha=0.5,dim=dim)
loss,accuracy,f1,recall,precision,roc_auc,training_time = runner.fit(outlier_percentage=outlier_percentage,invert=invert)
        print('TRAINING:',accuracy,f1,recall,precision,roc_auc,training_time)
loss,accuracy,f1,recall,precision,roc_auc = runner.predict(outlier_percentage=outlier_percentage,invert=invert)
print('TESTING:',accuracy,f1,recall,precision,roc_auc)
df[f'acc_{dim}'] = [round(accuracy,4)]
df[f'f1_{dim}'] = [round(f1,4)]
df[f'rec_{dim}'] = [round(recall,4)]
df[f'pre_{dim}'] = [round(precision,4)]
df[f'roc_{dim}'] = [round(roc_auc,4)]
df[f'tt_{dim}'] = [round(training_time,4)]
print(df)
df.to_csv(f'{dataset}_results.csv')
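# ===== model.py =====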
import torch.nn as nn
import torch.nn.functional as F
import torch
from layers import GraphConvolution
class Encoder(nn.Module):
def __init__(self, nfeat, nhid, dropout):
super(Encoder, self).__init__()
self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nhid)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = F.relu(self.gc2(x, adj))
return x
class Attribute_Decoder(nn.Module):
def __init__(self, nfeat, nhid, dropout):
super(Attribute_Decoder, self).__init__()
self.gc1 = GraphConvolution(nhid, nhid)
self.gc2 = GraphConvolution(nhid, nfeat)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = F.relu(self.gc2(x, adj))
return x
class Structure_Decoder(nn.Module):
def __init__(self, nhid, dropout):
super(Structure_Decoder, self).__init__()
self.gc1 = GraphConvolution(nhid, nhid)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = x @ x.T
return x
class Dominant(nn.Module):
def __init__(self, feat_size, hidden_size, dropout):
super(Dominant, self).__init__()
self.shared_encoder = Encoder(feat_size, hidden_size, dropout)
self.attr_decoder = Attribute_Decoder(feat_size, hidden_size, dropout)
self.struct_decoder = Structure_Decoder(hidden_size, dropout)
def forward(self, x, adj):
# encode
x = self.shared_encoder(x, adj)
# decode feature matrix
x_hat = self.attr_decoder(x, adj)
# decode adjacency matrix
struct_reconstructed = self.struct_decoder(x, adj)
# return reconstructed matrices
return struct_reconstructed, x_hat
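
# Illustrative usage sketch (not part of the original commit): run the
# Dominant autoencoder on a small synthetic graph and check that it returns
# a reconstructed adjacency matrix and a reconstructed feature matrix.
def _demo_dominant():
    torch.manual_seed(0)
    num_nodes, feat_size, hidden_size = 5, 8, 4
    model = Dominant(feat_size=feat_size, hidden_size=hidden_size, dropout=0.1)
    x = torch.randn(num_nodes, feat_size)
    adj = torch.eye(num_nodes).to_sparse()
    struct_hat, x_hat = model(x, adj)
    print(struct_hat.shape, x_hat.shape)  # torch.Size([5, 5]) torch.Size([5, 8])

# ===== runner.py =====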
import time
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score
from model import Dominant
from utils import load_anomaly_detection_dataset
class RUNNER:
def __init__(self,adj,features,labels,dim=20,lr=5e-3,alpha=0.5,dropout=0.1,device='cuda',epochs=50):
self.alpha = alpha
self.epochs = epochs
self.adj, self.attrs, self.label, self.adj_label = load_anomaly_detection_dataset(adj,features,labels)
self.model = Dominant(feat_size = self.attrs.size(1), hidden_size = dim, dropout = dropout)
if device == 'cuda':
device = torch.device(device)
self.adj = self.adj.to(device)
self.adj_label = self.adj_label.to(device)
self.attrs = self.attrs.to(device)
self.model = self.model.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr = lr)
def loss_func(self,adj, A_hat, attrs, X_hat, alpha):
# select only the nodes concerned by the split
# X_hat = X_hat[idx] #-----
# attrs = attrs[idx] #-----
# Attribute reconstruction loss
diff_attribute = torch.pow(X_hat - attrs, 2)
attribute_reconstruction_errors = torch.sqrt(torch.sum(diff_attribute, 1))
attribute_cost = torch.mean(attribute_reconstruction_errors)
# remove the nodes non concerned by the split
# mask = torch.zeros(A_hat.size(0), dtype=bool, device='cuda') #-------
# mask[idx] = True #------
# A_hat = A_hat[mask] #------
# A_hat = A_hat[:,mask] #------
# adj = adj[mask] #-------
# adj = adj[:,mask] #-------
# structure reconstruction loss
diff_structure = torch.pow(A_hat - adj, 2)
structure_reconstruction_errors = torch.sqrt(torch.sum(diff_structure, 1))
structure_cost = torch.mean(structure_reconstruction_errors)
cost = alpha * attribute_reconstruction_errors + (1-alpha) * structure_reconstruction_errors
return cost, structure_cost, attribute_cost
def fit(self,outlier_percentage,invert=False):
train_losses = []
best_losses = []
start = time.time()
for epoch in range(self.epochs):
self.model.train()
self.optimizer.zero_grad()
A_hat, X_hat = self.model(self.attrs, self.adj)
loss, struct_loss, feat_loss = self.loss_func(self.adj_label, A_hat, self.attrs, X_hat, self.alpha)
l = torch.mean(loss)
train_losses.append(l.item())
if(len(best_losses) == 0):
best_losses.append(train_losses[0])
elif (best_losses[-1] > train_losses[-1]):
best_losses.append(train_losses[-1])
else:
best_losses.append(best_losses[-1])
l.backward()
self.optimizer.step()
print("Epoch:", '%04d' % (epoch), "train_loss=", "{:.5f}".format(l.item()), "str_loss=", "{:.5f}".format(struct_loss.item()), "feat_loss=", "{:.5f}".format(feat_loss.item()))
score = loss.detach().cpu().numpy()
sorted_indices = np.argsort(score)
sorted_score = np.array(score)[sorted_indices]
sorted_label = np.array(self.label)[sorted_indices]
        if invert:
            # anomalies are assumed to have the lowest reconstruction error
            sorted_score[:int(len(sorted_score)*outlier_percentage)] = 1
            sorted_score[int(len(sorted_score)*outlier_percentage):] = 0
        else:
            # anomalies are assumed to have the highest reconstruction error
            sorted_score[:int(len(sorted_score)*(1-outlier_percentage))] = 0
            sorted_score[int(len(sorted_score)*(1-outlier_percentage)):] = 1
acc = accuracy_score(sorted_label,sorted_score)
f1 = f1_score(sorted_label,sorted_score)
rec = recall_score(sorted_label,sorted_score)
pre = precision_score(sorted_label,sorted_score)
auc = roc_auc_score(sorted_label, sorted_score)
print('Acc',acc,'f1',f1,'Rec',rec,'pre',pre,'Auc',auc)
# if(self.idx_val.shape[0] >0):
# self.model.eval()
# A_hat, X_hat = self.model(self.attrs, self.adj)
# loss, struct_loss, feat_loss = self.loss_func(self.adj_label, A_hat, self.attrs, X_hat, self.alpha,self.idx_val)
# score = loss.detach().cpu().numpy()
# sorted_indices = np.argsort(score)
# sorted_score = np.array(score)[sorted_indices]
# sorted_label = np.array(self.label)[sorted_indices]
# sorted_score[:int(len(sorted_score)/2)] = 1
# sorted_score[int(len(sorted_score)/2):] = 0
# print("VALIDATION:",'Acc',accuracy_score(sorted_label,sorted_score),'f1',f1_score(sorted_label,sorted_score),'Rec',recall_score(sorted_label,sorted_score),'pre',precision_score(sorted_label,sorted_score),'Auc', roc_auc_score(sorted_label, sorted_score))
# print('\n')
return {'train':train_losses,'best':best_losses,'val':[]},acc,f1,rec,pre,auc,time.time() - start
def predict(self,outlier_percentage,invert=False):
self.model.eval()
A_hat, X_hat = self.model(self.attrs, self.adj)
loss, struct_loss, feat_loss = self.loss_func(self.adj_label, A_hat, self.attrs, X_hat, self.alpha)
score = loss.detach().cpu().numpy()
sorted_indices = np.argsort(score)
sorted_score = np.array(score)[sorted_indices]
sorted_label = np.array(self.label)[sorted_indices]
        if invert:
            # anomalies are assumed to have the lowest reconstruction error
            sorted_score[:int(len(sorted_score)*outlier_percentage)] = 1
            sorted_score[int(len(sorted_score)*outlier_percentage):] = 0
        else:
            # anomalies are assumed to have the highest reconstruction error
            sorted_score[:int(len(sorted_score)*(1-outlier_percentage))] = 0
            sorted_score[int(len(sorted_score)*(1-outlier_percentage)):] = 1
print('\n')
print('\n')
acc = accuracy_score(sorted_label,sorted_score)
f1 = f1_score(sorted_label,sorted_score)
rec = recall_score(sorted_label,sorted_score)
pre = precision_score(sorted_label,sorted_score)
auc = roc_auc_score(sorted_label, sorted_score)
return np.mean(score),acc,f1,rec,pre,auc
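
# Illustrative sketch (not part of the original commit): how the ranking step
# in fit()/predict() turns per-node reconstruction errors into hard 0/1
# predictions. With invert=False the top outlier_percentage fraction of
# scores is flagged as anomalous.
def _demo_thresholding():
    score = np.array([0.2, 0.9, 0.1, 0.5])
    outlier_percentage = 0.25
    sorted_score = score[np.argsort(score)].copy()  # ascending: [0.1, 0.2, 0.5, 0.9]
    cut = int(len(sorted_score) * (1 - outlier_percentage))
    sorted_score[:cut] = 0
    sorted_score[cut:] = 1
    print(sorted_score)  # [0. 0. 0. 1.] -> only the largest error is flagged

# ===== repeated-runs experiment script (original file name not shown) =====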
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import os
import math
from runner import RUNNER
import pickle
def load_samples(dataset,r):
adj_path = os.path.join(os.getcwd(),'datasets',dataset,f'adjacency_{r}.pkl')
features_path = os.path.join(os.getcwd(),'datasets',dataset,f'features_{r}.pkl')
labels_path = os.path.join(os.getcwd(),'datasets',dataset,f'labels_{r}.pkl')
with open(adj_path,'rb') as f:
adj = pickle.load(f)
with open(features_path,'rb') as f:
features = pickle.load(f)
with open(labels_path,'rb') as f:
labels = pickle.load(f)
print('features:',features.shape,'adj',adj.shape,'labels',labels.shape)
return adj,features,labels
def testing(NAME,TEST_PROP,EPOCHS,EMBEDDING_DIMS,REPETITIONS,DATASET):
statistics = {}
for dim in EMBEDDING_DIMS:
print(f'START DIM {dim} ***************************')
dim_statistics = []
for r in range(REPETITIONS):
torch.cuda.empty_cache()
adj,features,labels = load_samples(DATASET,r)
model = None
model = RUNNER(adj=adj,features=features,labels=labels,dim=dim,epochs=EPOCHS,lr=1e-3,alpha=0.95)
print(f"START REPETITION: {NAME} -------------------------------")
loss_training,acc_training,f1_training,recall_training,precision_training,roc_auc_training,training_time = model.fit()
loss_testing,acc_testing,f1_testing,recall_testing,precision_testing,roc_auc_testing = model.predict()
stats = {'train':{'loss':loss_training,'acc':acc_training,'f1':f1_training,'recall':recall_training,'precision':precision_training,'roc_auc':roc_auc_training,'time':training_time},
'test':{'loss':loss_testing,'acc':acc_testing,'f1':f1_testing,'recall':recall_testing,'precision':precision_testing,'roc_auc':roc_auc_testing}}
dim_statistics.append(stats)
del model
statistics[f'{dim}'] = dim_statistics
print('THE TESTING IS OVER !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
return statistics
def calculate_avg_std(stats,list_scores,etiq):
avgs = []
stds = []
for score in list_scores:
all_values = []
for repetition_stats in stats:
all_values.append(repetition_stats[etiq][score])
avgs.append(np.mean(all_values))
stds.append(np.std(all_values))
return avgs,stds
def print_stats_per_dim(list_dim,stats,name,repetition):
list_scores_train = ['acc','f1','recall','precision','roc_auc','time']
list_scores_test = ['acc','f1','recall','precision','roc_auc']
df = pd.DataFrame({})
for i in range(len(list_dim)):
dim = list_dim[i]
dstats = stats[str(dim)]
print("THE STATISTICS OF THE DIMENSION {}:".format(dim))
print('',"TESING:","---------",sep='\n')
avgs,stds = calculate_avg_std(dstats,list_scores_test,'test')
for i in range(len(list_scores_test)):
print(f"{list_scores_test[i]}: avg={avgs[i]}, std={stds[i]}")
df[f'{dim}_{list_scores_test[i]}_avg'] = [round(avgs[i],4)]
df[f'{dim}_{list_scores_test[i]}_ci'] = [round(1.96*stds[i]/math.sqrt(repetition),4)]
df[f'{dim}_{list_scores_test[i]}_std'] = [round(stds[i],4)]
print('',"********************************","",sep='\n')
print('',"TRAINING:","---------",sep='\n')
avgs,stds = calculate_avg_std(dstats,list_scores_train,'train')
for i in range(len(list_scores_train)):
print(f"{list_scores_train[i]}: avg={avgs[i]}, std={stds[i]}")
df[f'{dim}_{list_scores_train[-1]}_avg'] = [round(avgs[-1],4)]
df[f'{dim}_{list_scores_train[-1]}_ci'] = [round(1.96*stds[-1]/math.sqrt(repetition),4)]
df[f'{dim}_{list_scores_train[-1]}_std'] = [round(stds[-1],4)]
path = os.path.join(os.getcwd(),'results',name,f'{name}_statistics.csv')
df.to_csv(path, index=False)
def plot_dim_loss(dim,stats,epochs,name,etiq):
x = range(epochs)
for i in range(len(stats)):
y_real = stats[i]['train']['loss'][etiq]
plt.plot(x, y_real,label=f'rep {i}')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()
path = os.path.join(os.getcwd(),'results',name,f'{name}_{dim}_{etiq}.png')
plt.savefig(path)
plt.close()
def plot_dim_loss_train_val(dim,stats,epochs,name,etiq):
x = range(epochs)
y_real = stats[-1]['train']['loss'][etiq]
y_val = stats[-1]['train']['loss']['val']
plt.plot(x, y_real,label=f'train')
plt.plot(x, y_val,label=f'val')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()
path = os.path.join(os.getcwd(),'results',name,f'{name}_{dim}_{etiq}_val_train.png')
plt.savefig(path)
plt.close()
def plot_loss(statistics,name,dims,epochs):
for dim in dims:
plot_dim_loss(dim,statistics[str(dim)],epochs,name,'train')
plot_dim_loss(dim,statistics[str(dim)],epochs,name,'best')
# plot_dim_loss_train_val(dim,statistics[str(dim)],epochs,name,'train')
# plot_dim_loss_train_val(dim,statistics[str(dim)],epochs,name,'best')
def do_it():
EPOCHS = 50
REPETITIONS = 5
TEST_PROP = .3
EMBEDDING_DIMS= [2,5,10,20]
NAME = 'DOMINANT'
DATASET = 'ton_iot'
statistics = testing(NAME,TEST_PROP,EPOCHS,EMBEDDING_DIMS,REPETITIONS,DATASET=DATASET)
print_stats_per_dim(EMBEDDING_DIMS,statistics,name=NAME,repetition=REPETITIONS)
plot_loss(statistics,NAME,EMBEDDING_DIMS,EPOCHS)
do_it()
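# ===== utils.py =====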
import numpy as np
import scipy.sparse as sp
import torch
def load_anomaly_detection_dataset(adj,feat,truth):
adj_norm = normalize_adj(adj)
adj_norm = adj_norm.toarray()
adj_norm = torch.FloatTensor(adj_norm)
adj = torch.FloatTensor(adj)
feat = torch.FloatTensor(feat)
return adj_norm, feat, truth, adj
def split_data(labels, val_prop, test_prop, seed):
np.random.seed(seed)
pos_idx = labels.nonzero()[0]
neg_idx = (1. - labels).nonzero()[0]
np.random.shuffle(pos_idx)
np.random.shuffle(neg_idx)
pos_idx = pos_idx.tolist()
neg_idx = neg_idx.tolist()
nb_pos_neg = min(len(pos_idx), len(neg_idx))
nb_val = round(val_prop * nb_pos_neg)
nb_test = round(test_prop * nb_pos_neg)
idx_val_pos, idx_test_pos, idx_train_pos = pos_idx[:nb_val], pos_idx[nb_val:nb_val + nb_test], pos_idx[
nb_val + nb_test:]
idx_val_neg, idx_test_neg, idx_train_neg = neg_idx[:nb_val], neg_idx[nb_val:nb_val + nb_test], neg_idx[
nb_val + nb_test:]
return idx_val_pos + idx_val_neg, idx_test_pos + idx_test_neg, idx_train_pos + idx_train_neg
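
# Illustrative sketch (not part of the original commit): split_data on a toy
# label vector. Split sizes are derived from the size of the smaller class,
# so each split stays balanced between positives and negatives.
def _demo_split_data():
    labels = np.array([1, 0, 0, 1, 0, 1, 0, 0, 1, 0])  # 4 positives, 6 negatives
    idx_val, idx_test, idx_train = split_data(labels, val_prop=0.25, test_prop=0.25, seed=0)
    print(len(idx_val), len(idx_test), len(idx_train))  # 2 2 6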
def normalize_adj(adj):
"""Symmetrically normalize adjacency matrix."""
adj = sp.coo_matrix(adj)
rowsum = np.array(adj.sum(1))
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
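
# Illustrative sketch (not part of the original commit): symmetric
# normalization D^{-1/2} A D^{-1/2} on a 3-node path graph. Degrees are
# [1, 2, 1], so each edge (i, j) is rescaled by 1/sqrt(d_i * d_j).
def _demo_normalize_adj():
    A = np.array([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]], dtype=float)
    print(normalize_adj(A).toarray())  # off-diagonal entries become 1/sqrt(2)

# ===== config.py =====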
import argparse
import torch
import os
parser = argparse.ArgumentParser(description='HGWaveNet')
parser.add_argument('--nfeat', type=int, default=75, help='dim of input feature')
parser.add_argument('--nhid', type=int, default=75, help='dim of hidden embedding')
parser.add_argument('--nout', type=int, default=2, help='dim of output embedding')
parser.add_argument('--num_nodes', type=int, default=1000, help='number of nodes per graph')
parser.add_argument('--nclasses', type=int, default=2, help='number of classes')
parser.add_argument('--lr', type=float, default=0.5, help='learning rate')
parser.add_argument('--max_epoch', type=int, default=20, help='number of epochs to train.')
parser.add_argument('--patience', type=int, default=20, help='patience for early stop')
parser.add_argument('--min_epoch', type=int, default=1, help='min epoch')
parser.add_argument('--weight_decay', type=float, default=0.01, help='weight for L2 loss on basic model.')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate (1 - keep probability).')
parser.add_argument('--heads', type=int, default=1, help='attention heads.')
parser.add_argument('--curvature', type=float, default=1.0, help='curvature value')
parser.add_argument('--trainable_curvature', type=bool, default=False, help='trainable curvature or not')
parser.add_argument('--aggregation', type=str, default='att', help='aggregation method: [deg, att]')
parser.add_argument('--timelength', type=int, default=26, help='total number of snapshots')
parser.add_argument('--testlength', type=float, default=.3, help='fraction of snapshots used for testing')
parser.add_argument('--dataset', type=str, default='dblp', help='dataset name')
parser.add_argument('--data_pt_path', type=str, default='./data/', help='parent path of dataset')
parser.add_argument('--device', type=int, default=0, help='gpu id, -1 for cpu')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--repeat', type=int, default=1, help='running times')
parser.add_argument('--sampling_times', type=int, default=1, help='negative sampling times')
parser.add_argument('--log_interval', type=int, default=1, help='log interval, default: 20,[20,40,...]')
parser.add_argument('--pre_defined_feature', default=None, help='pre-defined node feature')
parser.add_argument('--save_embeddings', type=int, default=0, help='save or not, default:0')
parser.add_argument('--output_pt_path', type=str, default='./output/', help='parent path of output')
parser.add_argument('--debug_mode', type=int, default=0, help='debug_mode, 0: normal running; 1: debugging mode')
parser.add_argument('--use_riemannian_adam', type=bool, default=True,
help='use riemannian adam or original adam as optimizer')
parser.add_argument('--model', type=str, default='HGWaveNet', help='model name')
parser.add_argument('--manifold', type=str, default='PoincareBall', help='hyperbolic model')
parser.add_argument('--use_hyperdecoder', type=bool, default=True, help='use hyperbolic decoder or not')
parser.add_argument('--spatial_dilated_factors', type=list, default=[1, 2],
help='dilated factor for dilated spatial convolution')
parser.add_argument('--casual_conv_depth', type=int, default=3, help='number of temporal casual convolution layers')
parser.add_argument('--casual_conv_kernel_size', type=int, default=2,
help='temporal casual convolution kernel size')
parser.add_argument('--eps', type=float, default=1e-15, help='eps')
parser.add_argument('--bias', type=bool, default=True, help='use bias or not')
parser.add_argument('--trainable_feat', type=bool, default=False,
help='using trainable feat or one-hot feat, default: trainable feat')
args = parser.parse_args()
if args.device >= 0 and torch.cuda.is_available():
args.device = torch.device('cuda:{}'.format(args.device))
else:
args.device = torch.device('cpu')
print('Using device {} to train the model ...'.format(args.device))
args.output_path = os.path.join(args.output_pt_path, args.dataset)
if not os.path.isdir(args.output_path):
os.makedirs(args.output_path)
args.log_file = os.path.join(args.output_path, '{}.log'.format(args.model))
args.emb_file = os.path.join(args.output_path, '{}.emb'.format(args.model))
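# ===== layers/__init__.py =====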
from .layers import Linear, FermiDiracDecoder
from .lorentz_layers import LorentzLinear, LorentzGraphNeuralNetwork, LorentzGraphDecoder
from .poincare_layers import HGCNConv, HGATConv, HypLinear, HypConv1d
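# ===== layers/layers.py =====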
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
def get_dim_act(args):
"""
Helper function to get dimension and activation at every layer.
:param args:
:return:
"""
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim] + ([args.dim] * (args.num_layers - 1))
if args.task in ['lp', 'rec']:
dims += [args.dim]
acts += [act]
return dims, acts
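
# Illustrative sketch (not part of the original commit): get_dim_act with a
# hypothetical args namespace. For a 3-layer link-prediction ('lp') model it
# produces one extra output dimension and activation.
def _demo_get_dim_act():
    from types import SimpleNamespace
    args = SimpleNamespace(act='relu', num_layers=3, feat_dim=16, dim=8, task='lp')
    dims, acts = get_dim_act(args)
    print(dims, len(acts))  # [16, 8, 8, 8] 3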
class Linear(Module):
"""
Simple Linear layer with dropout.
"""
def __init__(self, in_features, out_features, dropout, act, use_bias):
super(Linear, self).__init__()
self.dropout = dropout
self.linear = nn.Linear(in_features, out_features, use_bias)
self.act = act
def forward(self, x):
hidden = self.linear.forward(x)
hidden = F.dropout(hidden, self.dropout, training=self.training)
out = self.act(hidden)
return out
def reset_parameters(self):
self.linear.reset_parameters()
print('reset Euclidean defined Linear')
class FermiDiracDecoder(Module):
"""Fermi Dirac to compute edge probabilities based on distances."""
def __init__(self, r, t):
super(FermiDiracDecoder, self).__init__()
self.r = r
self.t = t
def forward(self, dist):
probs = 1. / (torch.exp((dist - self.r) / self.t) + 1)
return probs
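
# Illustrative sketch (not part of the original commit): the Fermi-Dirac
# decoder maps distances to edge probabilities. At dist == r the probability
# is exactly 0.5, and it decays towards 0 as the distance grows.
def _demo_fermi_dirac():
    decoder = FermiDiracDecoder(r=2.0, t=1.0)
    dist = torch.tensor([0.0, 2.0, 6.0])
    print(decoder(dist))  # approximately [0.88, 0.50, 0.02]

# ===== layers/lorentz_layers.py =====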
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.modules.module import Module
from layers.layers import Linear
def get_dim_act_curv(args):
"""
get dimension and activation in each layer
:param args:
:return:
"""
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim] + ([args.dim] * (args.num_layers - 1))
if args.task in ['lp']:
dims += [args.dim]
acts += [act]
n_curvatures = args.num_layers
else:
n_curvatures = args.num_layers - 1
if args.c is None:
curvatures = [nn.Parameter(torch.Tensor([1.])) for _ in range(n_curvatures)]
else:
curvatures = [torch.tensor([args.c]) for _ in range(n_curvatures)]
if not args.cuda == -1:
curvatures = [curv.to(args.device) for curv in curvatures]
return dims, acts, curvatures
class LorentzLinear(nn.Module):
"""
Lorentz Hyperbolic Graph Neural Layer
"""
def __init__(self, manifold, in_features, out_features, c, drop_out, use_bias):
super(LorentzLinear, self).__init__()
self.manifold = manifold
self.in_features = in_features
self.out_features = out_features
self.c = c
self.drop_out = drop_out
self.use_bias = use_bias
self.bias = nn.Parameter(torch.Tensor(out_features - 1)) # -1 when use mine mat-vec multiply
self.weight = nn.Parameter(torch.Tensor(out_features - 1, in_features)) # -1, 0 when use mine mat-vec multiply
self.reset_parameters()
def report_weight(self):
print(self.weight)
def reset_parameters(self):
init.xavier_uniform_(self.weight, gain=math.sqrt(2))
init.constant_(self.bias, 0)
# print('reset lorentz linear layer')
def forward(self, x):
drop_weight = F.dropout(self.weight, self.drop_out, training=self.training)
mv = self.manifold.matvec_regular(drop_weight, x, self.bias, self.c, self.use_bias)
return mv
def extra_repr(self):
return 'in_features={}, out_features={}, c={}'.format(
self.in_features, self.out_features, self.c
)
class LorentzAgg(Module):
"""
Lorentz centroids aggregation layer
"""
def __init__(self, manifold, c, use_att, in_features, dropout):
super(LorentzAgg, self).__init__()
self.manifold = manifold
self.c = c
self.use_att = use_att
self.in_features = in_features
self.dropout = dropout
self.this_spmm = SpecialSpmm()
if use_att:
self.att = LorentzSparseSqDisAtt(manifold, c, in_features - 1, dropout)
def lorentz_centroid(self, weight, x, c):
"""
Lorentz centroid
:param weight: dense weight matrix. shape: [num_nodes, num_nodes]
:param x: feature matrix [num_nodes, features]
:param c: parameter of curvature
:return: the centroids of nodes [num_nodes, features]
"""
if self.use_att:
sum_x = self.this_spmm(weight[0], weight[1], weight[2], x)
else:
sum_x = torch.spmm(weight, x)
x_inner = self.manifold.l_inner(sum_x, sum_x)
coefficient = (c ** 0.5) / torch.sqrt(torch.abs(x_inner))
return torch.mul(coefficient, sum_x.transpose(-2, -1)).transpose(-2, -1)
def forward(self, x, adj):
if self.use_att:
adj = self.att(x, adj)
output = self.lorentz_centroid(adj, x, self.c)
return output
def extra_repr(self):
return 'c={}, use_att={}'.format(
self.c, self.use_att
)
def reset_parameters(self):
if self.use_att:
self.att.reset_parameters()
# print('reset agg finished')
class LorentzAct(Module):
"""
Lorentz activation layer
"""
def __init__(self, manifold, c_in, c_out, act):
super(LorentzAct, self).__init__()
self.manifold = manifold
self.c_in = c_in
self.c_out = c_out
self.act = act
def forward(self, x):
xt = self.act(self.manifold.log_map_zero(x, c=self.c_in))
xt = self.manifold.normalize_tangent_zero(xt, self.c_in)
return self.manifold.exp_map_zero(xt, c=self.c_out)
def extra_repr(self):
return 'c_in={}, c_out={}'.format(
self.c_in, self.c_out
)
class LorentzGraphNeuralNetwork(nn.Module):
def __init__(self, manifold, in_feature, out_features, c_in, c_out, drop_out, act, use_bias, use_att):
super(LorentzGraphNeuralNetwork, self).__init__()
self.manifold = manifold
self.c_in = c_in
self.c_out = c_out
self.linear = LorentzLinear(manifold, in_feature, out_features, c_in, drop_out, use_bias)
self.agg = LorentzAgg(manifold, c_in, use_att, out_features, drop_out)
self.lorentz_act = LorentzAct(manifold, c_in, c_out, act)
self.ll = Linear(2 * out_features, out_features, drop_out, act, use_bias)
def forward(self, _input):
x, adj = _input
h = self.linear.forward(x)
h = self.agg.forward(h, adj)
h = self.lorentz_act.forward(h)
output = h, adj
return output
def reset_parameters(self):
self.linear.reset_parameters()
self.agg.reset_parameters()
class SpecialSpmmFunction(torch.autograd.Function):
"""Special function for only sparse region backpropataion layer."""
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad is False
device = b.device
a = torch.sparse_coo_tensor(indices, values, shape, device=device)
ctx.save_for_backward(a, b)
ctx.N = shape[0]
return torch.matmul(a, b)
@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b
class SpecialSpmm(nn.Module):
@staticmethod
def forward(indices, values, shape, b):
return SpecialSpmmFunction.apply(indices, values, shape, b)
class LorentzSparseSqDisAtt(nn.Module):
def __init__(self, manifold, c, in_features, dropout):
super(LorentzSparseSqDisAtt, self).__init__()
self.dropout = dropout
self.in_features = in_features
self.manifold = manifold
self.c = c
self.weight_linear = LorentzLinear(manifold, in_features, in_features + 1, c, dropout, True)
def forward(self, x, adj):
d = x.size(1) - 1
x = self.weight_linear(x)
index = adj._indices()
_x = x[index[0, :]]
_y = x[index[1, :]]
_x_head = _x.narrow(1, 0, 1)
_y_head = _y.narrow(1, 0, 1)
_x_tail = _x.narrow(1, 1, d)
_y_tail = _y.narrow(1, 1, d)
l_inner = -_x_head.mul(_y_head).sum(-1) + _x_tail.mul(_y_tail).sum(-1)
res = torch.clamp(-(self.c + l_inner), min=1e-10, max=1)
res = torch.exp(-res)
return index, res, adj.size()
class LorentzGraphDecoder(nn.Module):
"""
Lorentzian graph neural network decoder
"""
def __init__(self, manifold, in_feature, out_features, c_in, c_out, drop_out, act, use_bias, use_att):
super(LorentzGraphDecoder, self).__init__()
self.manifold = manifold
self.c_in = c_in
self.out_features = out_features + 1 # original output equal to num_classes
self.in_features = in_feature
self.linear = LorentzLinear(manifold, in_feature - 1, self.out_features, c_in, drop_out, False)
self.agg = LorentzAgg(manifold, c_in, use_att, self.out_features, drop_out)
self.lorentz_act = LorentzAct(manifold, c_in, c_out, act)
        self.bias = nn.Parameter(torch.Tensor(self.out_features)) if use_bias else None
        if self.bias is not None:
            init.constant_(self.bias, 0)
def forward(self, _input):
x, adj = _input
# print('=====x', x.shape, self.in_features)
h = self.linear.forward(x) # problem is h1+
h = self.agg.forward(h, adj)
h = self.lorentz_act.forward(h)
# b = self.manifold.ptransp0(h, self.bias, self.c_in)
# b = self.manifold.exp_map_x(h, b, self.c_in)
poincare_h = self.manifold.lorentz2poincare(h, self.c_in)
output = poincare_h, adj
return output
def reset_parameters(self):
self.linear.reset_parameters()
self.agg.reset_parameters()
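# ===== layers/poincare_layers.py =====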
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, softmax, add_self_loops
from torch_scatter import scatter, scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch.nn.parameter import Parameter
from torch_geometric.nn.inits import glorot, zeros
class HGATConv(nn.Module):
"""
Poincare graph convolution layer.
"""
def __init__(self, manifold, in_features, out_features, c_in, c_out, device, act=F.leaky_relu,
dropout=0.6, att_dropout=0.6, use_bias=True, heads=2, concat=False):
super(HGATConv, self).__init__()
out_features = out_features * heads
self.linear = HypLinear(manifold, in_features, out_features, c_in, device, dropout=dropout, use_bias=use_bias)
self.agg = HypAttAgg(manifold, c_in, out_features, device, att_dropout, heads=heads, concat=concat)
self.hyp_act = HypAct(manifold, c_in, c_out, act)
self.manifold = manifold
self.c_in = c_in
self.c_out = c_out
self.device = device
def forward(self, x, edge_index):
h = self.linear.forward(x)
h = self.agg.forward(h, edge_index)
h = self.hyp_act.forward(h)
return h
class HGCNConv(nn.Module):
"""
    Poincare graph convolution layer, from HGCN.
"""
def __init__(self, manifold, in_features, out_features, device, c_in=1.0, c_out=1.0, dropout=0.6, act=F.leaky_relu,
use_bias=True):
super(HGCNConv, self).__init__()
self.linear = HypLinear(manifold, in_features, out_features, c_in, device, dropout=dropout, use_bias=use_bias)
self.agg = HypAgg(manifold, c_in, out_features, device, bias=use_bias)
self.hyp_act = HypAct(manifold, c_in, c_out, act)
self.manifold = manifold
self.c_in = c_in
self.device = device
def forward(self, x, edge_index):
h = self.linear.forward(x)
h = self.agg.forward(h, edge_index)
h = self.hyp_act.forward(h)
return h
class HypLinear(nn.Module):
"""
Poincare linear layer.
"""
def __init__(self, manifold, in_features, out_features, c, device, dropout=0.6, use_bias=True):
super(HypLinear, self).__init__()
self.manifold = manifold
self.in_features = in_features
self.out_features = out_features
self.c = c
self.device = device
self.dropout = dropout
self.use_bias = use_bias
self.bias = Parameter(torch.Tensor(out_features).to(device), requires_grad=True)
self.weight = Parameter(torch.Tensor(out_features, in_features).to(device), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
glorot(self.weight)
zeros(self.bias)
def forward(self, x):
drop_weight = F.dropout(self.weight, p=self.dropout, training=self.training)
mv = self.manifold.mobius_matvec(drop_weight, x, self.c)
res = self.manifold.proj(mv, self.c)
if self.use_bias:
bias = self.manifold.proj_tan0(self.bias.view(1, -1), self.c)
hyp_bias = self.manifold.expmap0(bias, self.c)
hyp_bias = self.manifold.proj(hyp_bias, self.c)
res = self.manifold.mobius_add(res, hyp_bias, c=self.c)
res = self.manifold.proj(res, self.c)
return res
def extra_repr(self):
return 'in_features={}, out_features={}, c={}'.format(
self.in_features, self.out_features, self.c
)
class HypAttAgg(MessagePassing):
def __init__(self, manifold, c, out_features, device, att_dropout=0.6, heads=1, concat=False):
super(HypAttAgg, self).__init__()
self.manifold = manifold
self.dropout = att_dropout
self.out_channels = out_features // heads
self.negative_slope = 0.2
self.heads = heads
self.c = c
self.device = device
self.concat = concat
self.att_i = Parameter(torch.Tensor(1, heads, self.out_channels).to(device), requires_grad=True)
self.att_j = Parameter(torch.Tensor(1, heads, self.out_channels).to(device), requires_grad=True)
glorot(self.att_i)
glorot(self.att_j)
def forward(self, x, edge_index):
edge_index, _ = remove_self_loops(edge_index)
edge_index, _ = add_self_loops(edge_index,
num_nodes=x.size(self.node_dim))
edge_index_i = edge_index[0]
edge_index_j = edge_index[1]
x_tangent0 = self.manifold.logmap0(x, c=self.c) # project to origin
x_i = torch.nn.functional.embedding(edge_index_i, x_tangent0)
x_j = torch.nn.functional.embedding(edge_index_j, x_tangent0)
x_i = x_i.view(-1, self.heads, self.out_channels)
x_j = x_j.view(-1, self.heads, self.out_channels)
alpha = (x_i * self.att_i).sum(-1) + (x_j * self.att_j).sum(-1)
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(alpha, edge_index_i, num_nodes=x_i.size(0))
alpha = F.dropout(alpha, self.dropout, training=self.training)
support_t = scatter(x_j * alpha.view(-1, self.heads, 1), edge_index_i, dim=0)
if self.concat:
support_t = support_t.view(-1, self.heads * self.out_channels)
else:
support_t = support_t.mean(dim=1)
support_t = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
return support_t
class HypAct(Module):
"""
Poincare activation layer.
"""
def __init__(self, manifold, c_in, c_out, act):
super(HypAct, self).__init__()
self.manifold = manifold
self.c_in = c_in
self.c_out = c_out
self.act = act
def forward(self, x):
xt = self.act(self.manifold.logmap0(x, c=self.c_in))
xt = self.manifold.proj_tan0(xt, c=self.c_out)
return self.manifold.proj(self.manifold.expmap0(xt, c=self.c_out), c=self.c_out)
def extra_repr(self):
return 'c_in={}, c_out={}'.format(
self.c_in, self.c_out
)
class HypAgg(MessagePassing):
"""
Poincare aggregation layer using degree.
"""
def __init__(self, manifold, c, out_features, device, bias=True):
super(HypAgg, self).__init__()
self.manifold = manifold
self.c = c
self.device = device
self.use_bias = bias
if bias:
self.bias = Parameter(torch.Tensor(out_features).to(device))
else:
self.register_parameter('bias', None)
zeros(self.bias)
self.mlp = nn.Sequential(nn.Linear(out_features * 2, 1).to(device))
@staticmethod
def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
device=edge_index.device)
fill_value = 1 if not improved else 2
edge_index, edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index=None):
x_tangent = self.manifold.logmap0(x, c=self.c)
edge_index, norm = self.norm(edge_index, x.size(0), dtype=x.dtype)
node_i = edge_index[0]
node_j = edge_index[1]
x_j = torch.nn.functional.embedding(node_j, x_tangent)
support = norm.view(-1, 1) * x_j
support_t = scatter(support, node_i, dim=0, dim_size=x.size(0)) # aggregate the neighbors of node_i
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
return output
def extra_repr(self):
return 'c={}'.format(self.c)
class HypConv1d(nn.Module):
def __init__(self, manifold, in_size, out_size, kernel_size, c, device, dilation=1, stride=1):
super(HypConv1d, self).__init__()
self.manifold = manifold
self.in_size = in_size
self.out_size = out_size
self.kernel_size = kernel_size
self.c = c
self.device = device
self.dilation = dilation
self.stride = stride
self.pad = (kernel_size - 1) // 2 * dilation
self.conv = nn.Conv1d(in_size, out_size, kernel_size, padding=self.pad,
stride=stride, dilation=dilation, device=device)
self.reset_parameters()
def reset_parameters(self):
glorot(self.conv.weight)
def to_tangent(self, x, c=1.0):
x_tan = self.manifold.logmap0(x, c)
x_tan = self.manifold.proj_tan0(x_tan, c)
return x_tan
def to_hyper(self, x, c=1.0):
x_tan = self.manifold.proj_tan0(x, c)
x_hyp = self.manifold.expmap0(x_tan, c)
x_hyp = self.manifold.proj(x_hyp, c)
return x_hyp
def forward(self, x):
x = self.to_tangent(self.manifold.proj(x, self.c), self.c)
x = x.permute(1, 2, 0)
x = self.conv(x)
x = x.permute(2, 0, 1)
x = self.to_hyper(x, self.c)
return x
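# ===== loss.py =====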
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
from manifolds import Euclidean, Lorentzian, PoincareBall
from sklearn.metrics import accuracy_score, f1_score, recall_score,precision_score,roc_auc_score
class ReconLoss(nn.Module):
def __init__(self, args, c,model):
super(ReconLoss, self).__init__()
if args.manifold == 'PoincareBall':
self.manifold = PoincareBall()
elif args.manifold == 'Lorentzian':
self.manifold = Lorentzian()
elif args.manifold == 'Euclidean':
self.manifold = Euclidean()
else:
raise RuntimeError('invalid argument: manifold')
self.device = args.device
self.c = c
self.eps = args.eps
self.negative_sampling = negative_sampling
self.sampling_times = args.sampling_times
self.model = model
#self.fermidirac_decoder = FermiDiracDecoder(2.0, 1.0)
#self.use_hyperdecoder = args.use_hyperdecoder and (not isinstance(self.manifold, Euclidean))
# @staticmethod
# def decoder(z, edge_index):
# value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
# return torch.sigmoid(value)
# def hyperdecoder(self, z, edge_index):
# edge_i = edge_index[0]
# edge_j = edge_index[1]
# z_i = F.embedding(edge_i, z)
# z_j = F.embedding(edge_j, z)
# dist = self.manifold.sqdist(z_i, z_j, self.c).squeeze()
# return self.fermidirac_decoder(dist)
# def forward(self, z, pos_edge_index, neg_edge_index=None):
# decoder = self.hyperdecoder if self.use_hyperdecoder else self.decoder
# pos_loss = -torch.log(decoder(z, pos_edge_index) + self.eps).mean()
# if neg_edge_index is None:
# neg_edge_index = negative_sampling(pos_edge_index,
# num_neg_samples=pos_edge_index.size(1) * self.sampling_times)
# neg_loss = -torch.log(1 - decoder(z, neg_edge_index) + self.eps).mean()
# return pos_loss + neg_loss
# def predict(self, z, pos_edge_index, neg_edge_index):
# decoder = self.hyperdecoder if self.use_hyperdecoder else self.decoder
# pos_y = z.new_ones(pos_edge_index.size(1)).to(self.device)
# neg_y = z.new_zeros(neg_edge_index.size(1)).to(self.device)
# y = torch.cat([pos_y, neg_y], dim=0)
# pos_pred = decoder(z, pos_edge_index)
# neg_pred = decoder(z, neg_edge_index)
# pred = torch.cat([pos_pred, neg_pred], dim=0)
# y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
# return roc_auc_score(y, pred), average_precision_score(y, pred)
# class FermiDiracDecoder(nn.Module):
# """Fermi Dirac to compute edge probabilities based on distances."""
# def __init__(self, r, t):
# super(FermiDiracDecoder, self).__init__()
# self.r = r
# self.t = t
# def forward(self, dist):
# probs = 1. / (torch.exp((dist - self.r) / self.t) + 1.0)
# return probs
def forward(self, z,labels):
output = self.model.decode(z)
return F.nll_loss(output, labels)
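# ===== training script (original file name not shown) =====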
import torch
import numpy as np
import time
import geoopt
import networkx as nx
from math import isnan
from config import args
from utils.data_utils import prepare
from utils.util import set_random, logger
from model import HGWaveNet
from loss import ReconLoss
from sklearn.metrics import accuracy_score, f1_score, recall_score,precision_score,roc_auc_score
class Trainer(object):
def __init__(self):
        self.train_shots = list(range(0, args.timelength - int(args.testlength * args.timelength)))
        self.test_shots = list(range(args.timelength - int(args.testlength * args.timelength), args.timelength))
self.model = HGWaveNet(args).to(args.device)
self.loss = ReconLoss(args, self.model.c_out,self.model)
self.optimizer = geoopt.optim.radam.RiemannianAdam(self.model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
set_random(args.seed)
def calculate_metrics(self,z,labels):
preds = self.model.decode(z)
labels = labels.cpu().detach().numpy()
preds = preds.cpu().detach().numpy()
preds = np.argmax(preds,axis=1)
f1 = f1_score(labels, preds)
accuracy = accuracy_score(labels, preds)
recall = recall_score(labels, preds)
precision = precision_score(labels, preds)
roc_auc = roc_auc_score(labels,preds )
return f1,accuracy,precision,recall,roc_auc
def train(self):
t_total = time.time()
min_loss = 1.0e8
patience = 0
for epoch in range(1, args.max_epoch + 1):
t_epoch = time.time()
epoch_losses = []
z = None
epoch_loss = None
self.model.init_history()
self.model.train()
for t in self.train_shots:
dilated_edge_index,features,labels = prepare(t,args.spatial_dilated_factors,args.device)
self.optimizer.zero_grad()
z = self.model(dilated_edge_index,features)
epoch_loss = self.loss(z, labels) + self.model.htc(z)
logger.info('Epoch:{}, Snapshot: {}; Loss: {:.4f}'.format(epoch, t, epoch_loss.item()))
epoch_loss.backward()
if isnan(epoch_loss):
logger.info('==' * 45)
logger.info('nan loss')
break
self.optimizer.step()
epoch_losses.append(epoch_loss.item())
self.model.update_history(z)
if isnan(epoch_loss):
break
gpu_mem_alloc = torch.cuda.max_memory_allocated() / 1000000 if torch.cuda.is_available() else 0
average_epoch_loss = np.mean(epoch_losses)
if average_epoch_loss < min_loss:
min_loss = average_epoch_loss
patience = 0
else:
patience += 1
if epoch > args.min_epoch and patience > args.patience:
logger.info('==' * 45)
logger.info('early stopping!')
break
if epoch == 1 or epoch % args.log_interval == 0:
test_results = self.test()
logger.info('==' * 45)
logger.info("Epoch:{}, Loss: {:.4f}, Time: {:.3f}, GPU: {:.1f}MiB".format(epoch, average_epoch_loss,
time.time() - t_epoch,
gpu_mem_alloc))
logger.info('Epoch:{}, Accuracy: {:.4f}; F1: {:.4f}; Recall: {:.4f}; Precision: {:.4f}; ROC AUC: {:.4f}, Memory: {:.4f}, time: {:.4f}'.format(epoch, test_results[1], test_results[0],test_results[2],test_results[3],test_results[4],test_results[5],test_results[6]))
logger.info('==' * 45)
logger.info('Total time: {:.3f}'.format(time.time() - t_total))
def test(self):
f1_list,acc_list,pre_list,rec_list,roc_list,occupied_memory,ptime = [], [], [], [], [],[],[]
self.model.eval()
for t in self.test_shots:
edge_index,features,labels = prepare(t,args.spatial_dilated_factors,args.device)
start = time.time()
embeddings = self.model(edge_index,features)
f1,accuracy,precision,recall,roc_auc = self.calculate_metrics(embeddings, labels)
ptime.append(time.time() - start)
occupied_memory.append(torch.cuda.max_memory_allocated() / 1000000 if torch.cuda.is_available() else 0)
f1_list.append(f1)
acc_list.append(accuracy)
pre_list.append(precision)
rec_list.append(recall)
roc_list.append(roc_auc)
return np.mean(f1_list), np.mean(acc_list), np.mean(rec_list), np.mean(pre_list), np.mean(roc_list), np.mean(occupied_memory),np.mean(ptime) * 1000
if __name__ == '__main__':
trainer = Trainer()
trainer.train()
print(trainer.test())
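# ===== manifolds/__init__.py =====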
from .base import ManifoldParameter
from .euclidean import Euclidean
from .lorentzian import Lorentzian
from .poincare import PoincareBall
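# ===== manifolds/base.py =====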
from torch.nn import Parameter
class Manifold(object):
"""
Abstract class to define operations on a manifold.
"""
def __init__(self):
super().__init__()
self.eps = 10e-8
def sqdist(self, p1, p2, c):
"""Squared distance between pairs of points."""
raise NotImplementedError
def dist(self, p1, p2, c):
"""Distance between a pair of points"""
raise NotImplementedError
def egrad2rgrad(self, p, dp, c):
"""Converts Euclidean Gradient to Riemannian Gradients."""
raise NotImplementedError
def proj(self, p, c):
"""Projects point p on the manifold."""
raise NotImplementedError
def proj_tan(self, u, p, c):
"""Projects u on the tangent space of p."""
raise NotImplementedError
def proj_tan0(self, u, c):
"""Projects u on the tangent space of the origin."""
raise NotImplementedError
def expmap(self, u, p, c):
"""Exponential map of u at point p."""
raise NotImplementedError
def logmap(self, p1, p2, c):
"""Logarithmic map of point p1 at point p2."""
raise NotImplementedError
def expmap0(self, u, c):
"""Exponential map of u at the origin."""
raise NotImplementedError
def logmap0(self, p, c):
"""Logarithmic map of point p at the origin."""
raise NotImplementedError
def mobius_add(self, x, y, c, dim=-1):
"""Adds points x and y."""
raise NotImplementedError
def mobius_matvec(self, m, x, c):
"""Performs hyperboic martrix-vector multiplication."""
raise NotImplementedError
def init_weights(self, w, c, irange=1e-5):
"""Initializes random weigths on the manifold."""
raise NotImplementedError
def inner(self, p, c, u, v=None):
"""Inner product for tangent vectors at point x."""
raise NotImplementedError
def ptransp(self, x, y, u, c):
"""Parallel transport of u from x to y."""
raise NotImplementedError
    # additional operations defined for this project
def l_inner(self, x, y, keep_dim):
"""Lorentz inner"""
raise NotImplementedError
def induced_distance(self, x, y, c):
"""Metric distance"""
raise NotImplementedError
def lorentzian_distance(self, x, y, c):
"""lorzentzian distance"""
raise NotImplementedError
def exp_map_x(self, p, dp, c, is_res_normalize, is_dp_normalize):
raise NotImplementedError
def exp_map_zero(self, dp, c, is_res_normalize, is_dp_normalize):
raise NotImplementedError
def log_map_x(self, x, y, c, is_tan_normalize):
raise NotImplementedError
def log_map_zero(self, y, c, is_tan_normalize):
raise NotImplementedError
def matvec_proj(self, m, x, c):
raise NotImplementedError
def matvecbias_proj(self, m, x, b, c):
raise NotImplementedError
def matvec_regular(self, m, x, c):
raise NotImplementedError
def matvecbias_regular(self, m, x, b, c):
raise NotImplementedError
def normalize_tangent_zero(self, p_tan, c):
raise NotImplementedError
def lorentz_centroid(self, weight, x, c):
raise NotImplementedError
def normalize_input(self, x, c):
raise NotImplementedError
def normlize_tangent_bias(self, x, c):
raise NotImplementedError
def proj_tan_zero(self, u, c):
raise NotImplementedError
def lorentz2poincare(self, x, c):
raise NotImplementedError
def poincare2lorentz(self, x, c):
raise NotImplementedError
def _lambda_x(self, x, c):
raise NotImplementedError
class ManifoldParameter(Parameter):
"""
Subclass of torch.nn.Parameter for Riemannian optimization.
"""
def __new__(cls, data, requires_grad, manifold, c):
return Parameter.__new__(cls, data, requires_grad)
def __init__(self, data, requires_grad, manifold, c):
super().__init__(data, requires_grad)
self.c = c
self.manifold = manifold
def __repr__(self):
return '{} Parameter containing:\n'.format(self.manifold.name) + super(Parameter, self).__repr__()
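# ===== manifolds/euclidean.py =====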
from manifolds.base import Manifold
class Euclidean(Manifold):
"""
Euclidean Manifold class.
"""
def __init__(self):
super(Euclidean, self).__init__()
self.name = 'Euclidean'
@staticmethod
def normalize(p):
dim = p.size(-1)
p.view(-1, dim).renorm_(2, 0, 1.)
return p
def sqdist(self, p1, p2, c):
return (p1 - p2).pow(2).sum(dim=-1)
def egrad2rgrad(self, p, dp, c):
return dp
def proj(self, p, c):
return p
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
return p + u
def logmap(self, p1, p2, c):
return p2 - p1
def expmap0(self, u, c):
return u
def logmap0(self, p, c):
return p
def mobius_add(self, x, y, c, dim=-1):
return x + y
def mobius_matvec(self, m, x, c):
mx = x @ m.transpose(-1, -2)
return mx
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def inner(self, p, c, u, v=None, keepdim=False):
if v is None:
v = u
return (u * v).sum(dim=-1, keepdim=keepdim)
def ptransp(self, x, y, v, c):
return v
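# ===== manifolds/lorentzian.py =====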
import torch
from manifolds.base import Manifold
from utils.math_utils import arcosh
class Lorentzian(Manifold):
"""
Hyperboloid Manifold class.
for x in (d+1)-dimension Euclidean space
    -x0^2 + x1^2 + x2^2 + ... + xd^2 = -c, x0 > 0, c > 0
negative curvature - 1 / c
"""
def __init__(self):
super(Lorentzian, self).__init__()
self.name = 'Lorentzian'
self.max_norm = 1000
self.min_norm = 1e-8
self.eps = {torch.float32: 1e-6, torch.float64: 1e-8}
def l_inner(self, x, y, keep_dim=False):
# input shape [node, features]
d = x.size(-1) - 1
xy = x * y
xy = torch.cat((-xy.narrow(1, 0, 1), xy.narrow(1, 1, d)), dim=1)
return torch.sum(xy, dim=1, keepdim=keep_dim)
def sqdist(self, p1, p2, c):
dist = self.lorentzian_distance(p1, p2, c)
dist = torch.clamp(dist, min=self.eps[p1.dtype], max=50)
return dist
def induced_distance(self, x, y, c):
xy_inner = self.l_inner(x, y)
sqrt_c = c ** 0.5
return sqrt_c * arcosh(-xy_inner / c + self.eps[x.dtype])
def lorentzian_distance(self, x, y, c):
# the squared Lorentzian distance
xy_inner = self.l_inner(x, y)
return -2 * (c + xy_inner)
def egrad2rgrad(self, p, dp, c):
"""
Transform the Euclidean gradient to Riemannian gradient
:param p: vector in hyperboloid
:param dp: gradient with Euclidean geometry
:param c: parameter of curvature
:return: gradient with Riemannian geometry
"""
dp.narrow(-1, 0, 1).mul_(-1) # multiply g_l^-1
dp.addcmul_(self.l_inner(p, dp, keep_dim=True).expand_as(p) / c, p)
return dp
def normalize(self, p, c):
"""
Normalize vector to confirm it is located on the hyperboloid
:param p: [nodes, features(d + 1)]
:param c: parameter of curvature
"""
d = p.size(-1) - 1
narrowed = p.narrow(-1, 1, d)
if self.max_norm:
narrowed = torch.renorm(narrowed.view(-1, d), 2, 0, self.max_norm)
first = c + torch.sum(torch.pow(narrowed, 2), dim=-1, keepdim=True)
first = torch.sqrt(first)
return torch.cat((first, narrowed), dim=1)
def proj(self, p, c):
return self.normalize(p, c)
def normalize_tangent(self, p, p_tan, c):
"""
        Normalize tangent vectors so that they satisfy <p, p_tan>_L = 0
:param p: the tangent spaces at p. size:[nodes, feature]
:param p_tan: the tangent vector in tangent space at p
:param c: parameter of curvature
"""
d = p_tan.size(1) - 1
p_tail = p.narrow(1, 1, d)
p_tan_tail = p_tan.narrow(1, 1, d)
ptpt = torch.sum(p_tail * p_tan_tail, dim=1, keepdim=True)
p_head = torch.sqrt(c + torch.sum(torch.pow(p_tail, 2), dim=1, keepdim=True) + self.eps[p.dtype])
return torch.cat((ptpt / p_head, p_tan_tail), dim=1)
def normalize_tangent_zero(self, p_tan, c):
zeros = torch.zeros_like(p_tan)
zeros[:, 0] = c ** 0.5
return self.normalize_tangent(zeros, p_tan, c)
def exp_map_x(self, p, dp, c, is_res_normalize=True, is_dp_normalize=True):
if is_dp_normalize:
dp = self.normalize_tangent(p, dp, c)
dp_lnorm = self.l_inner(dp, dp, keep_dim=True)
dp_lnorm = torch.sqrt(torch.clamp(dp_lnorm + self.eps[p.dtype], 1e-6))
dp_lnorm_cut = torch.clamp(dp_lnorm, max=50)
sqrt_c = c ** 0.5
res = (torch.cosh(dp_lnorm_cut / sqrt_c) * p) + sqrt_c * (torch.sinh(dp_lnorm_cut / sqrt_c) * dp / dp_lnorm)
if is_res_normalize:
res = self.normalize(res, c)
return res
def exp_map_zero(self, dp, c, is_res_normalize=True, is_dp_normalize=True):
zeros = torch.zeros_like(dp)
zeros[:, 0] = c ** 0.5
return self.exp_map_x(zeros, dp, c, is_res_normalize, is_dp_normalize)
def log_map_x(self, x, y, c, is_tan_normalize=True):
"""
Logarithmic map at x: project hyperboloid vectors to a tangent space at x
:param x: vector on hyperboloid
:param y: vector to project a tangent space at x
:param c: parameter of curvature
:param is_tan_normalize: whether normalize the y_tangent
:return: y_tangent
"""
xy_distance = self.induced_distance(x, y, c)
tmp_vector = y + self.l_inner(x, y, keep_dim=True) / c * x
tmp_norm = torch.sqrt(self.l_inner(tmp_vector, tmp_vector) + self.eps[x.dtype])
y_tan = xy_distance.unsqueeze(-1) / tmp_norm.unsqueeze(-1) * tmp_vector
if is_tan_normalize:
y_tan = self.normalize_tangent(x, y_tan, c)
return y_tan
def log_map_zero(self, y, c, is_tan_normalize=True):
zeros = torch.zeros_like(y)
zeros[:, 0] = c ** 0.5
return self.log_map_x(zeros, y, c, is_tan_normalize)
def logmap0(self, p, c):
return self.log_map_zero(p, c)
def proj_tan(self, u, p, c):
"""
project vector u into the tangent vector at p
:param u: the vector in Euclidean space
:param p: the vector on a hyperboloid
:param c: parameter of curvature
"""
return u - self.l_inner(u, p, keep_dim=True) / self.l_inner(p, p, keep_dim=True) * p
def proj_tan_zero(self, u, c):
zeros = torch.zeros_like(u)
# print(zeros)
zeros[:, 0] = c ** 0.5
return self.proj_tan(u, zeros, c)
def proj_tan0(self, u, c):
return self.proj_tan_zero(u, c)
def normalize_input(self, x, c):
# print('=====normalize original input===========')
num_nodes = x.size(0)
zeros = torch.zeros(num_nodes, 1, dtype=x.dtype, device=x.device)
x_tan = torch.cat((zeros, x), dim=1)
return self.exp_map_zero(x_tan, c)
def matvec_regular(self, m, x, b, c, use_bias):
d = x.size(1) - 1
x_tan = self.log_map_zero(x, c)
x_head = x_tan.narrow(1, 0, 1)
x_tail = x_tan.narrow(1, 1, d)
mx = x_tail @ m.transpose(-1, -2)
if use_bias:
mx_b = mx + b
else:
mx_b = mx
mx = torch.cat((x_head, mx_b), dim=1)
mx = self.normalize_tangent_zero(mx, c)
mx = self.exp_map_zero(mx, c)
        cond = (mx == 0).all(-1, keepdim=True)
res = torch.zeros(1, dtype=mx.dtype, device=mx.device)
res = torch.where(cond, res, mx)
return res
def lorentz_centroid(self, weight, x, c):
sum_x = torch.spmm(weight, x)
# print('weight x', sum_x)
x_inner = self.l_inner(sum_x, sum_x)
coefficient = (c ** 0.5) / torch.sqrt(torch.abs(x_inner))
return torch.mul(coefficient, sum_x.transpose(-2, -1)).transpose(-2, -1)
def lorentz2poincare(self, x, c):
try:
radius = torch.sqrt(c)
except:
radius = c ** 0.5
d = x.size(-1) - 1
return (x.narrow(-1, 1, d) * radius) / (x.narrow(-1, 0, 1) + radius)
def poincare2lorentz(self, x, c):
x_norm_square = torch.sum(x * x, dim=1, keepdim=True)
return torch.cat((1 + x_norm_square, 2 * x), dim=1) / (1 - x_norm_square + 1e-8)
def ptransp0(self, y, v, c):
# y: target point
zeros = torch.zeros_like(v)
zeros[:, 0] = c ** 0.5
v = self.normalize_tangent_zero(v, c)
return self.ptransp(zeros, y, v, c)
def ptransp(self, x, y, v, c):
# transport v from x to y
K = 1. / c
yv = self.l_inner(y, v, keep_dim=True)
xy = self.l_inner(x, y, keep_dim=True)
_frac = K * yv / (1 - K * xy)
return v + _frac * (x + y)
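
# Illustrative sketch (not part of the original commit): points produced by
# exp_map_zero should satisfy the hyperboloid constraint <x, x>_L = -c.
def _demo_lorentzian_constraint():
    manifold = Lorentzian()
    c = 1.0
    u = torch.randn(4, 3, dtype=torch.float64) * 0.1  # tangent vectors at the origin
    x = manifold.exp_map_zero(u, c)
    print(manifold.l_inner(x, x))  # approximately [-1., -1., -1., -1.]

# ===== manifolds/poincare.py =====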
"""Poincare ball manifold."""
import torch
from manifolds.base import Manifold
from utils.math_utils import artanh, tanh
class PoincareBall(Manifold):
"""
    PoincareBall Manifold class.
We use the following convention: x0^2 + x1^2 + ... + xd^2 < 1 / c
Note that 1/sqrt(c) is the Poincare ball radius.
"""
def __init__(self, ):
super(PoincareBall, self).__init__()
self.name = 'PoincareBall'
self.min_norm = 1e-15
self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}
def sqdist(self, p1, p2, c):
sqrt_c = c ** 0.5
dist_c = artanh(
sqrt_c * self.mobius_add(-p1, p2, c, dim=-1).norm(dim=-1, p=2, keepdim=False)
)
dist = dist_c * 2 / sqrt_c
return dist ** 2
@staticmethod
def dist0(p1, c, keepdim=False):
sqrt_c = c ** 0.5
dist_c = artanh(
sqrt_c * p1.norm(dim=-1, p=2, keepdim=keepdim)
)
dist = dist_c * 2 / sqrt_c
return dist
def _lambda_x(self, x, c):
x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
return 2 / (1. - c * x_sqnorm).clamp_min(self.min_norm)
def egrad2rgrad(self, p, dp, c):
lambda_p = self._lambda_x(p, c)
dp /= lambda_p.pow(2)
return dp
def proj(self, x, c):
norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
cond = norm > maxnorm
projected = x / norm * maxnorm
return torch.where(cond, projected, x)
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
sqrt_c = c ** 0.5
u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
second_term = (
tanh(sqrt_c / 2 * self._lambda_x(p, c) * u_norm)
* u
/ (sqrt_c * u_norm)
)
gamma_1 = self.mobius_add(p, second_term, c)
return gamma_1
def logmap(self, p1, p2, c):
sub = self.mobius_add(-p1, p2, c)
sub_norm = sub.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
lam = self._lambda_x(p1, c)
sqrt_c = c ** 0.5
return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm
def expmap0(self, u, c):
sqrt_c = c ** 0.5
u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm)
gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
return gamma_1
def logmap0(self, p, c):
sqrt_c = c ** 0.5
p_norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
scale = 1. / sqrt_c * artanh(sqrt_c * p_norm) / p_norm
return scale * p
def mobius_add(self, x, y, c, dim=-1):
x2 = x.pow(2).sum(dim=dim, keepdim=True)
y2 = y.pow(2).sum(dim=dim, keepdim=True)
xy = (x * y).sum(dim=dim, keepdim=True)
num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
return num / denom.clamp_min(self.min_norm)
def mobius_matvec(self, m, x, c):
sqrt_c = c ** 0.5
x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
mx = x @ m.transpose(-1, -2)
mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
cond = (mx == 0).prod(-1, keepdim=True, dtype=torch.bool)
res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
res = torch.where(cond, res_0, res_c)
return res
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def _gyration(self, u, v, w, c, dim: int = -1):
u2 = u.pow(2).sum(dim=dim, keepdim=True)
v2 = v.pow(2).sum(dim=dim, keepdim=True)
uv = (u * v).sum(dim=dim, keepdim=True)
uw = (u * w).sum(dim=dim, keepdim=True)
vw = (v * w).sum(dim=dim, keepdim=True)
c2 = c ** 2
a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw
b = -c2 * vw * u2 - c * uw
d = 1 + 2 * c * uv + c2 * u2 * v2
return w + 2 * (a * u + b * v) / d.clamp_min(self.min_norm)
def inner(self, x, c, u, v=None, keepdim=False):
if v is None:
v = u
lambda_x = self._lambda_x(x, c)
return lambda_x ** 2 * (u * v).sum(dim=-1, keepdim=keepdim)
def ptransp(self, x, y, u, c):
lambda_x = self._lambda_x(x, c)
lambda_y = self._lambda_x(y, c)
return self._gyration(y, -x, u, c) * lambda_x / lambda_y
def ptransp_(self, x, y, u, c):
lambda_x = self._lambda_x(x, c)
lambda_y = self._lambda_x(y, c)
return self._gyration(y, -x, u, c) * lambda_x / lambda_y
def ptransp0(self, x, u, c):
lambda_x = self._lambda_x(x, c)
return 2 * u / lambda_x.clamp_min(self.min_norm)
@staticmethod
def to_hyperboloid(x, c):
K = 1. / c
sqrtK = K ** 0.5
sqnorm = torch.norm(x, p=2, dim=1, keepdim=True) ** 2
return sqrtK * torch.cat([K + sqnorm, 2 * sqrtK * x], dim=1) / (K - sqnorm)
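
# Illustrative sketch (not part of the original commit): two sanity checks on
# PoincareBall. Mobius addition with the origin is the identity, and
# expmap0/logmap0 are mutually inverse.
def _demo_poincare_ball():
    ball = PoincareBall()
    c = 1.0
    x = torch.tensor([[0.3, -0.1], [0.05, 0.2]], dtype=torch.float64)
    print(ball.mobius_add(torch.zeros_like(x), x, c))  # equals x
    print(ball.expmap0(ball.logmap0(x, c), c))         # recovers x

# ===== model package __init__.py (HGWaveNet) =====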
from .hgwavenet import HGWaveNet
from .spatial_dilated_conv import SpatialDilatedConv
from .temporal_casual_conv import TemporalCasualConv