import torch
from geoopt.manifolds import PoincareBall as PoincareBallParent
from geoopt.manifolds.stereographic.math import _lambda_x, arsinh, tanh
MIN_NORM = 1e-15
class PoincareBall(PoincareBallParent):
def __init__(self, dim, c=1.0):
super().__init__(c)
self.register_buffer("dim", torch.as_tensor(dim, dtype=torch.int))
def proju0(self, u):
return self.proju(self.zero.expand_as(u), u)
@property
def coord_dim(self):
return int(self.dim)
@property
def device(self):
return self.c.device
@property
def zero(self):
return torch.zeros(1, self.dim).to(self.device)
    def logdetexp(self, x, y, is_vector=False, keepdim=False):
        # log-determinant of the exp-map Jacobian: (dim - 1) * log(sinh(sqrt(c) d) / (sqrt(c) d))
        d = self.norm(x, y, keepdim=keepdim) if is_vector else self.dist(x, y, keepdim=keepdim)
        d[d == 0] = MIN_NORM
        return (self.dim - 1) * (torch.sinh(self.c.sqrt() * d) / self.c.sqrt() / d).log()
def inner(self, x, u, v=None, *, keepdim=False, dim=-1):
if v is None: v = u
return _lambda_x(x, self.c, keepdim=keepdim, dim=dim) ** 2 * (u * v).sum(
dim=dim, keepdim=keepdim
)
def expmap_polar(self, x, u, r, dim: int = -1):
sqrt_c = self.c ** 0.5
u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
second_term = (
tanh(sqrt_c / 2 * r)
* u
/ (sqrt_c * u_norm)
)
gamma_1 = self.mobius_add(x, second_term, dim=dim)
return gamma_1
def normdist2plane(self, x, a, p, keepdim: bool = False, signed: bool = False, dim: int = -1, norm: bool = False):
c = self.c
sqrt_c = c ** 0.5
diff = self.mobius_add(-p, x, dim=dim)
diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(MIN_NORM)
sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim)
if not signed:
sc_diff_a = sc_diff_a.abs()
a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(MIN_NORM)
num = 2 * sqrt_c * sc_diff_a
denom = (1 - c * diff_norm2) * a_norm
res = arsinh(num / denom.clamp_min(MIN_NORM)) / sqrt_c
if norm:
            res = res * a_norm  # * self.lambda_x(a, dim=dim, keepdim=keepdim)
return res
class PoincareBallExact(PoincareBall):
__doc__ = r"""
See Also
--------
:class:`PoincareBall`
Notes
-----
The implementation of retraction is an exact exponential map, this retraction will be used in optimization
"""
retr_transp = PoincareBall.expmap_transp
transp_follow_retr = PoincareBall.transp_follow_expmap
retr = PoincareBall.expmap
def extra_repr(self):
return "exact"
from Ghypeddings.PVAE.models.tabular import Tabular
__all__ = ['Tabular']
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import prod
from Ghypeddings.PVAE.utils import Constants
from Ghypeddings.PVAE.ops.manifold_layers import GeodesicLayer, MobiusLayer, LogZero, ExpZero
from torch.nn.modules.module import Module
def get_dim_act(args):
    """
    Helper function to get the dimension and activation at every layer.
    :param args: namespace carrying `act`, `num_layers`, `feat_dim` and `hidden_dim`.
    :return: a `(dims, acts)` pair with one entry per layer transition.
    """
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim] + ([args.hidden_dim] * (args.num_layers - 1))
return dims, acts
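# Hedged example (illustrative values): with act='relu', num_layers=3, feat_dim=16
# and hidden_dim=32, get_dim_act returns dims == [16, 32, 32] and acts == [F.relu, F.relu].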
class Encoder(nn.Module):
"""
Encoder abstract class.
"""
def __init__(self, c):
super(Encoder, self).__init__()
self.c = c
def encode(self, x, adj):
input = (x, adj)
output, _ = self.layers.forward(input)
return output
class GraphConvolution(Module):
"""
Simple GCN layer.
"""
def __init__(self, in_features, out_features, dropout, act, use_bias):
super(GraphConvolution, self).__init__()
self.dropout = dropout
self.linear = nn.Linear(in_features, out_features, use_bias)
self.act = act
self.in_features = in_features
self.out_features = out_features
def forward(self, input):
x, adj = input
hidden = self.linear.forward(x)
hidden = F.dropout(hidden, self.dropout, training=self.training)
if adj.is_sparse:
support = torch.spmm(adj, hidden)
else:
support = torch.mm(adj, hidden)
output = self.act(support), adj
return output
def extra_repr(self):
return 'input_dim={}, output_dim={}'.format(
self.in_features, self.out_features
)
class GCN(Encoder):
"""
Graph Convolution Networks.
"""
def __init__(self, c, args):
super(GCN, self).__init__(c)
assert args.num_layers > 0
dims, acts = get_dim_act(args)
gc_layers = []
for i in range(len(dims) - 1):
in_dim, out_dim = dims[i], dims[i + 1]
act = acts[i]
gc_layers.append(GraphConvolution(in_dim, out_dim, args.dropout, act, args.bias))
self.layers = nn.Sequential(*gc_layers)
def extra_hidden_layer(hidden_dim, non_lin):
return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), non_lin)
class EncWrapped(nn.Module):
""" Usual encoder followed by an exponential map """
def __init__(self,c,args, manifold, data_size, non_lin, num_hidden_layers, hidden_dim, prior_iso):
super(EncWrapped, self).__init__()
self.manifold = manifold
self.data_size = data_size
self.enc = GCN(c,args)
self.fc21 = nn.Linear(hidden_dim, manifold.coord_dim)
self.fc22 = nn.Linear(hidden_dim, manifold.coord_dim if not prior_iso else 1)
    def forward(self, adj, x):
        e = self.enc.encode(x, adj)
        mu = self.fc21(e)  # linear map from GCN features to latent coordinates
mu = self.manifold.expmap0(mu)
return mu, F.softplus(self.fc22(e)) + Constants.eta, self.manifold
class DecWrapped(nn.Module):
    """ Usual decoder preceded by a logarithm map """
def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
super(DecWrapped, self).__init__()
self.data_size = data_size
self.manifold = manifold
modules = []
modules.append(nn.Sequential(nn.Linear(manifold.coord_dim, hidden_dim), non_lin))
modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
self.dec = nn.Sequential(*modules)
# self.fc31 = nn.Linear(hidden_dim, prod(data_size))
self.fc31 = nn.Linear(hidden_dim, data_size[1])
def forward(self, z):
z = self.manifold.logmap0(z)
d = self.dec(z)
# mu = self.fc31(d).view(*z.size()[:-1], *self.data_size) # reshape data
mu = self.fc31(d).view(*z.size()[:-1], 1, self.data_size[1])
return mu, torch.ones_like(mu)
class DecGeo(nn.Module):
    """ First layer is a hypergyroplane (geodesic) layer followed by a usual decoder """
def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
super(DecGeo, self).__init__()
self.data_size = data_size
modules = []
modules.append(nn.Sequential(GeodesicLayer(manifold.coord_dim, hidden_dim, manifold), non_lin))
modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
self.dec = nn.Sequential(*modules)
self.fc31 = nn.Linear(hidden_dim, data_size[1])
def forward(self, z):
d = self.dec(z)
# mu = self.fc31(d).view(*z.size()[:-1], *self.data_size) # reshape data
mu = self.fc31(d).view(*z.size()[:-1], 1, self.data_size[1])
return mu, torch.ones_like(mu)
class EncMob(nn.Module):
    """ Last layer is a Mobius layer """
def __init__(self,c,args, manifold, data_size, non_lin, num_hidden_layers, hidden_dim, prior_iso):
super(EncMob, self).__init__()
self.manifold = manifold
self.data_size = data_size
self.enc = GCN(c,args)
self.fc21 = MobiusLayer(hidden_dim, manifold.coord_dim, manifold)
self.fc22 = nn.Linear(hidden_dim, manifold.coord_dim if not prior_iso else 1)
    def forward(self, adj, x):
        e = self.enc.encode(x, adj)
        mu = self.fc21(e)  # Mobius linear map into the ball's coordinates
mu = self.manifold.expmap0(mu)
return mu, F.softplus(self.fc22(e)) + Constants.eta, self.manifold
class DecMob(nn.Module):
    """ First layer is a Mobius matrix multiplication """
def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
super(DecMob, self).__init__()
self.data_size = data_size
modules = []
modules.append(nn.Sequential(MobiusLayer(manifold.coord_dim, hidden_dim, manifold), LogZero(manifold), non_lin))
modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
self.dec = nn.Sequential(*modules)
self.fc31 = nn.Linear(hidden_dim, prod(data_size))
def forward(self, z):
d = self.dec(z)
mu = self.fc31(d).view(*z.size()[:-1], *self.data_size) # reshape data
return mu, torch.ones_like(mu)
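# --- Hedged smoke test (added for illustration; not in the original file) ---
# Wires an encoder/decoder pair above with a tiny argparse-style namespace; all
# sizes are illustrative, and the PoincareBall import mirrors the one in tabular.py.
if __name__ == "__main__":
    from argparse import Namespace
    from Ghypeddings.PVAE.manifolds import PoincareBall
    args = Namespace(act='relu', num_layers=2, feat_dim=6, hidden_dim=16, dropout=0.0, bias=True)
    manifold = PoincareBall(dim=4, c=torch.ones(1))
    enc = EncWrapped(manifold.c, args, manifold, data_size=[5, 6], non_lin=nn.ReLU(),
                     num_hidden_layers=2, hidden_dim=16, prior_iso=False)
    dec = DecGeo(manifold, data_size=[5, 6], non_lin=nn.ReLU(), num_hidden_layers=2, hidden_dim=16)
    x, adj = torch.randn(5, 6), torch.eye(5)
    mu, scale, _ = enc(adj, x)            # mu lies on the Poincare ball, shape (5, 4)
    recon_mu, _ = dec(mu.unsqueeze(0))    # decoded mean, shape (1, 5, 1, 6)
    print(mu.shape, recon_mu.shape)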
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from torch.utils.data import DataLoader
import math
from Ghypeddings.PVAE.models.vae import VAE
from Ghypeddings.PVAE.distributions import RiemannianNormal, WrappedNormal
from torch.distributions import Normal
import Ghypeddings.PVAE.manifolds as manifolds
from Ghypeddings.PVAE.models.architectures import EncWrapped, DecWrapped, EncMob, DecMob, DecGeo
from Ghypeddings.PVAE.utils import get_activation
class Tabular(VAE):
""" Derive a specific sub-class of a VAE for tabular data. """
def __init__(self, params):
c = nn.Parameter(params.c * torch.ones(1), requires_grad=False)
manifold = getattr(manifolds, 'PoincareBall')(params.dim, c)
super(Tabular, self).__init__(
eval(params.prior), # prior distribution
eval(params.posterior), # posterior distribution
dist.Normal, # likelihood distribution
eval('Enc' + params.enc)(params.c,params,manifold, params.data_size, get_activation(params), params.num_layers, params.hidden_dim, params.prior_iso),
eval('Dec' + params.dec)(manifold, params.data_size, get_activation(params), params.num_layers, params.hidden_dim),
params
)
self.manifold = manifold
self._pz_mu = nn.Parameter(torch.zeros(1, params.dim), requires_grad=False)
self._pz_logvar = nn.Parameter(torch.zeros(1, 1), requires_grad=params.learn_prior_std)
self.modelName = 'Tabular'
@property
def pz_params(self):
return self._pz_mu.mul(1), F.softplus(self._pz_logvar).div(math.log(2)).mul(self.prior_std), self.manifold
# Base VAE class definition
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from Ghypeddings.PVAE.utils import get_mean_param
class VAE(nn.Module):
def __init__(self, prior_dist, posterior_dist, likelihood_dist, enc, dec, params):
super(VAE, self).__init__()
self.pz = prior_dist
self.px_z = likelihood_dist
self.qz_x = posterior_dist
self.enc = enc
self.dec = dec
self.modelName = None
self.params = params
self.data_size = params.data_size
self.prior_std = params.prior_std
if self.px_z == dist.RelaxedBernoulli:
self.px_z.log_prob = lambda self, value: \
-F.binary_cross_entropy_with_logits(
self.probs if value.dim() <= self.probs.dim() else self.probs.expand_as(value),
value.expand(self.batch_shape) if value.dim() <= self.probs.dim() else value,
reduction='none'
)
def generate(self, N, K):
self.eval()
with torch.no_grad():
mean_pz = get_mean_param(self.pz_params)
mean = get_mean_param(self.dec(mean_pz))
px_z_params = self.dec(self.pz(*self.pz_params).sample(torch.Size([N])))
means = get_mean_param(px_z_params)
samples = self.px_z(*px_z_params).sample(torch.Size([K]))
return mean, \
means.view(-1, *means.size()[2:]), \
samples.view(-1, *samples.size()[3:])
    def reconstruct(self, data, edge_index):
self.eval()
with torch.no_grad():
qz_x = self.qz_x(*self.enc(edge_index,data))
px_z_params = self.dec(qz_x.rsample(torch.Size([1])).squeeze(0))
return get_mean_param(px_z_params)
    def forward(self, x, edge_index, K=1):
        embeddings = self.enc(edge_index, x)
        qz_x = self.qz_x(*embeddings)
        zs = qz_x.rsample(torch.Size([K]))
        px_z = self.px_z(*self.dec(zs))
        return qz_x, px_z, zs, embeddings
    @property
    def pz_params(self):
        # overridden in subclasses (e.g. Tabular); note the base class stores
        # `prior_std`, not `prior_std_scale`
        return self._pz_mu.mul(1), F.softplus(self._pz_logvar).div(math.log(2)).mul(self.prior_std)
def init_last_layer_bias(self, dataset): pass
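# Hedged shape walkthrough of VAE.forward on graph data (illustrative), assuming
# node features x of shape (N, D) and K posterior samples:
#   self.enc(edge_index, x) -> (mu, scale, manifold), with mu of shape (N, d)
#   zs = qz_x.rsample([K])  -> (K, N, d) latent samples on the manifold
#   self.dec(zs)            -> likelihood parameters, e.g. shape (K, N, 1, D)
#                              for the tabular decoders in architectures.py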
import torch
import torch.distributions as dist
from numpy import prod
from Ghypeddings.PVAE.utils import has_analytic_kl, log_mean_exp
import torch.nn.functional as F
def vae_objective(model, idx, x , graph, K=1, beta=1.0, components=False, analytical_kl=False, **kwargs):
"""Computes E_{p(x)}[ELBO] """
qz_x, px_z, zs , embeddings = model(x, graph,K)
_, B, D = zs.size()
flat_rest = torch.Size([*px_z.batch_shape[:2], -1])
x = x.unsqueeze(0).unsqueeze(2)
lpx_z = px_z.log_prob(x.expand(px_z.batch_shape)).view(flat_rest).sum(-1)
pz = model.pz(*model.pz_params)
kld = dist.kl_divergence(qz_x, pz).unsqueeze(0).sum(-1) if \
has_analytic_kl(type(qz_x), model.pz) and analytical_kl else \
qz_x.log_prob(zs).sum(-1) - pz.log_prob(zs).sum(-1)
lpx_z_selected = lpx_z[:, idx]
kld_selected = kld[:, idx]
obj = -lpx_z_selected.mean(0).sum() + beta * kld_selected.mean(0).sum()
return (qz_x, px_z, lpx_z_selected, kld_selected, obj , embeddings) if components else obj
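# The objective above is the beta-weighted negative ELBO over the selected node
# indices, averaged over the K posterior samples:
#   obj = -E_q[log p(x|z)] + beta * KL(q(z|x) || p(z)),
# where the KL term is computed analytically when a closed form is registered,
# and otherwise by the Monte Carlo estimate log q(zs|x) - log p(zs).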
def _iwae_objective_vec(model, x, K):
    r"""Helper for IWAE estimate for log p_\theta(x) -- full vectorisation."""
qz_x, px_z, zs = model(x, K)
flat_rest = torch.Size([*px_z.batch_shape[:2], -1])
lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1)
lpx_z = px_z.log_prob(x.expand(zs.size(0), *x.size())).view(flat_rest).sum(-1)
lqz_x = qz_x.log_prob(zs).sum(-1)
obj = lpz.squeeze(-1) + lpx_z.view(lpz.squeeze(-1).shape) - lqz_x.squeeze(-1)
return -log_mean_exp(obj).sum()
def iwae_objective(model, x, K):
    r"""Computes an importance-weighted ELBO estimate for log p_\theta(x).
    Iterates over the batch as necessary.
    Appropriate negation (for minimisation) happens in the helper.
    """
    split_size = max(1, int(x.size(0) / (K * prod(x.size()) / (3e7))))  # rough heuristic to bound memory
    if split_size >= x.size(0):
obj = _iwae_objective_vec(model, x, K)
else:
obj = 0
for bx in x.split(split_size):
obj = obj + _iwae_objective_vec(model, bx, K)
return obj
import math
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.nn import init
from Ghypeddings.PVAE.manifolds import PoincareBall, Euclidean
from geoopt import ManifoldParameter
class RiemannianLayer(nn.Module):
def __init__(self, in_features, out_features, manifold, over_param, weight_norm):
super(RiemannianLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.manifold = manifold
self._weight = Parameter(torch.Tensor(out_features, in_features))
self.over_param = over_param
self.weight_norm = weight_norm
if self.over_param:
self._bias = ManifoldParameter(torch.Tensor(out_features, in_features), manifold=manifold)
else:
self._bias = Parameter(torch.Tensor(out_features, 1))
self.reset_parameters()
@property
def weight(self):
return self.manifold.transp0(self.bias, self._weight) # weight \in T_0 => weight \in T_bias
@property
def bias(self):
if self.over_param:
return self._bias
else:
return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold
def reset_parameters(self):
init.kaiming_normal_(self._weight, a=math.sqrt(5))
fan_in, _ = init._calculate_fan_in_and_fan_out(self._weight)
bound = 4 / math.sqrt(fan_in)
init.uniform_(self._bias, -bound, bound)
if self.over_param:
with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias))
class GeodesicLayer(RiemannianLayer):
def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):
super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)
    def forward(self, input):
        # broadcast each point against all out_features hyperplanes: (..., D) -> (..., out_features, D)
        input = input.unsqueeze(-2).expand(*input.shape[:-1], self.out_features, self.in_features)
res = self.manifold.normdist2plane(input, self.bias, self.weight,
signed=True, norm=self.weight_norm)
return res
class Linear(nn.Linear):
    def __init__(self, in_features, out_features, **kwargs):
        # extra kwargs are accepted for interface compatibility and ignored
        super(Linear, self).__init__(
            in_features,
            out_features,
        )
class MobiusLayer(RiemannianLayer):
def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):
super(MobiusLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)
def forward(self, input):
res = self.manifold.mobius_matvec(self.weight, input)
return res
class ExpZero(nn.Module):
def __init__(self, manifold):
super(ExpZero, self).__init__()
self.manifold = manifold
def forward(self, input):
return self.manifold.expmap0(input)
class LogZero(nn.Module):
def __init__(self, manifold):
super(LogZero, self).__init__()
self.manifold = manifold
def forward(self, input):
return self.manifold.logmap0(input)
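# --- Hedged smoke test (added for illustration; not in the original file) ---
# Exercises the layers above on a small batch of latents; sizes are illustrative
# and the PoincareBall used is the PVAE manifold imported at the top of this file.
if __name__ == "__main__":
    ball = PoincareBall(dim=4, c=torch.ones(1))
    z = ball.expmap0(0.1 * torch.randn(2, 3, 4))   # (K=2, B=3, D=4) latents
    geo = GeodesicLayer(4, 8, ball)
    print(geo(z).shape)                            # expected: torch.Size([2, 3, 8])
    mob = MobiusLayer(4, 8, ball)
    print(mob(z).shape)                            # expected: torch.Size([2, 3, 8])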
import sys
sys.path.append(".")
sys.path.append("..")
import os
import datetime
from collections import defaultdict
import torch
from torch import optim
import numpy as np
import logging
import time
from Ghypeddings.PVAE.utils import probe_infnan , process_data , create_args , get_classifier,get_clustering_algorithm,get_anomaly_detection_algorithm
import Ghypeddings.PVAE.objectives as objectives
from Ghypeddings.PVAE.models import Tabular
from Ghypeddings.classifiers import calculate_metrics
runId = datetime.datetime.now().isoformat().replace(':','_')
torch.backends.cudnn.benchmark = True
class PVAE:
def __init__(self,
adj,
features,
labels,
dim,
hidden_dim,
num_layers=2,
c=1.0,
act='relu',
lr=0.01,
cuda=0,
epochs=50,
seed=42,
eval_freq=1,
val_prop=0.,
test_prop=0.3,
dropout=0.1,
beta1=0.9,
beta2=.999,
K=1,
beta=.2,
analytical_kl=True,
posterior='WrappedNormal',
prior='WrappedNormal',
prior_iso=True,
prior_std=1.,
learn_prior_std=True,
enc='Mob',
dec='Geo',
bias=True,
alpha=0.5,
classifier=None,
clusterer=None,
log_freq=1,
normalize_adj=False,
normalize_feats=True,
anomaly_detector=None
):
self.args = create_args(dim,hidden_dim,num_layers,c,act,lr,cuda,epochs,seed,eval_freq,val_prop,test_prop,dropout,beta1,beta2,K,beta,analytical_kl,posterior,prior,prior_iso,prior_std,learn_prior_std,enc,dec,bias,alpha,classifier,clusterer,log_freq,normalize_adj,normalize_feats,anomaly_detector)
self.args.n_classes = len(np.unique(labels))
self.args.feat_dim = features.shape[1]
self.data = process_data(self.args,adj,features,labels)
self.args.data_size = [adj.shape[0],self.args.feat_dim]
self.args.batch_size=1
self.cls = None
        if int(self.args.cuda) >= 0:
            torch.cuda.manual_seed(self.args.seed)
            self.args.device = 'cuda:' + str(self.args.cuda)
        else:
            self.args.device = 'cpu'
self.args.prior_iso = self.args.prior_iso or self.args.posterior == 'RiemannianNormal'
        # Choosing and saving a random seed for reproducibility
        if self.args.seed == 0: self.args.seed = int(torch.randint(0, 2**32 - 1, (1,)).item())
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.cuda.manual_seed_all(self.args.seed)
        torch.backends.cudnn.deterministic = True
self.model = Tabular(self.args).to(self.args.device)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr, amsgrad=True, betas=(self.args.beta1, self.args.beta2))
self.loss_function = getattr(objectives,'vae_objective')
if self.args.cuda is not None and int(self.args.cuda) >= 0 :
os.environ['CUDA_VISIBLE_DEVICES'] = str(self.args.cuda)
self.model = self.model.to(self.args.device)
for x, val in self.data.items():
if torch.is_tensor(self.data[x]):
self.data[x] = self.data[x].to(self.args.device)
self.tb_embeddings = None
def fit(self):
tot_params = sum([np.prod(p.size()) for p in self.model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
t_total = time.time()
agg = defaultdict(list)
        b_loss, b_recon, b_kl, b_mlik, tb_loss = (sys.float_info.max,) * 5
best_losses = []
train_losses = []
val_losses = []
for epoch in range(self.args.epochs):
self.model.train()
self.optimizer.zero_grad()
qz_x, px_z, lik, kl, loss , embeddings = self.loss_function(self.model,self.data['idx_train'], self.data['features'], self.data['adj_train'], K=self.args.K, beta=self.args.beta, components=True, analytical_kl=self.args.analytical_kl)
probe_infnan(loss, "Training loss:")
loss.backward()
self.optimizer.step()
t_loss = loss.item() / len(self.data['idx_train'])
t_recon = -lik.mean(0).sum().item() / len(self.data['idx_train'])
t_kl = kl.sum(-1).mean(0).sum().item() / len(self.data['idx_train'])
if(t_loss < b_loss):
b_loss = t_loss
b_recon = t_recon
b_kl = t_kl
agg['train_loss'].append(t_loss )
agg['train_recon'].append(t_recon )
agg['train_kl'].append(t_kl )
train_losses.append(t_recon)
if(len(best_losses) == 0):
best_losses.append(train_losses[0])
elif (best_losses[-1] > train_losses[-1]):
best_losses.append(train_losses[-1])
else:
best_losses.append(best_losses[-1])
if (epoch + 1) % self.args.log_freq == 0:
print('====> Epoch: {:03d} Loss: {:.2f} Recon: {:.2f} KL: {:.2f}'.format(epoch, agg['train_loss'][-1], agg['train_recon'][-1], agg['train_kl'][-1]))
if (epoch + 1) % self.args.eval_freq == 0 and self.args.val_prop:
self.model.eval()
with torch.no_grad():
qz_x, px_z, lik, kl, loss , embeddings= self.loss_function(self.model,self.data['idx_val'], self.data['features'],self.data['adj_train'], K=self.args.K, beta=self.args.beta, components=True)
tt_loss = loss.item() / len(self.data['idx_val'])
val_losses.append(tt_loss)
if(tt_loss < tb_loss):
tb_loss = tt_loss
self.tb_embeddings = embeddings[0]
                    agg['test_loss'].append(tt_loss)
                    print('====> Validation loss: {:.4f}'.format(agg['test_loss'][-1]))
logging.info("Optimization Finished!")
logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
print('====> Training: Best Loss: {:.2f} Best Recon: {:.2f} Best KL: {:.2f}'.format(b_loss,b_recon,b_kl))
        print('====> Validation: Best Loss: {:.2f}'.format(tb_loss))
train_idx = self.data['idx_train']
val_idx = self.data['idx_val']
idx = np.unique(np.concatenate((train_idx,val_idx)))
X = self.model.manifold.logmap0(self.tb_embeddings[idx]).cpu().detach().numpy()
y = self.data['labels'].cpu().reshape(-1,1)[idx]
        if self.args.classifier:
            self.cls = get_classifier(self.args, X, y)
            acc, f1, recall, precision, roc_auc = calculate_metrics(self.cls, X, y)
        elif self.args.clusterer:
            y = y.reshape(-1,)
            acc, f1, recall, precision, roc_auc = get_clustering_algorithm(self.args.clusterer, X, y)[6:]
        elif self.args.anomaly_detector:
            y = y.reshape(-1,)
            acc, f1, recall, precision, roc_auc = get_anomaly_detection_algorithm(self.args.anomaly_detector, X, y)[6:]
        else:
            raise ValueError('one of classifier, clusterer or anomaly_detector must be set')
        return {'train': train_losses, 'best': best_losses, 'val': val_losses}, acc, f1, recall, precision, roc_auc, time.time() - t_total
def predict(self):
self.model.eval()
with torch.no_grad():
qz_x, px_z, lik, kl, loss , embeddings=self.loss_function(self.model,self.data['idx_test'], self.data['features'],self.data['adj_train'], K=self.args.K, beta=self.args.beta, components=True)
tt_loss = loss.item() / len(self.data['idx_test'])
test_idx = self.data['idx_test']
data = self.model.manifold.logmap0(embeddings[0][test_idx]).cpu().detach().numpy()
labels = self.data['labels'].reshape(-1,1).cpu()[test_idx]
        if self.args.classifier:
            acc, f1, recall, precision, roc_auc = calculate_metrics(self.cls, data, labels)
        elif self.args.clusterer:
            labels = labels.reshape(-1,)
            acc, f1, recall, precision, roc_auc = get_clustering_algorithm(self.args.clusterer, data, labels)[6:]
        elif self.args.anomaly_detector:
            labels = labels.reshape(-1,)
            acc, f1, recall, precision, roc_auc = get_anomaly_detection_algorithm(self.args.anomaly_detector, data, labels)[6:]
        else:
            raise ValueError('one of classifier, clusterer or anomaly_detector must be set')
        self.tb_embeddings = embeddings[0]
        return abs(tt_loss), acc, f1, recall, precision, roc_auc
def save_embeddings(self,directory):
tb_embeddings_euc = self.model.manifold.logmap0(self.tb_embeddings)
for_classification_hyp = np.hstack((self.tb_embeddings.cpu().detach().numpy(),self.data['labels'].reshape(-1,1).cpu()))
for_classification_euc = np.hstack((tb_embeddings_euc.cpu().detach().numpy(),self.data['labels'].reshape(-1,1).cpu()))
hyp_file_path = os.path.join(directory,'pvae_embeddings_hyp.csv')
euc_file_path = os.path.join(directory,'pvae_embeddings_euc.csv')
np.savetxt(hyp_file_path, for_classification_hyp, delimiter=',')
np.savetxt(euc_file_path, for_classification_euc, delimiter=',')
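# --- Hedged end-to-end sketch (added for illustration; not in the original file) ---
# Synthetic data and tiny settings; 'kmeans' is one of the clusterer names
# handled by get_clustering_algorithm in the utils module.
#   adj = np.eye(20)
#   features = np.random.rand(20, 5)
#   labels = np.random.randint(0, 2, 20)
#   model = PVAE(adj, features, labels, dim=4, hidden_dim=16, epochs=5, clusterer='kmeans')
#   losses, acc, f1, recall, precision, roc_auc, elapsed = model.fit()
#   model.save_embeddings('.')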
import sys
import math
import time
import os
import shutil
import torch
import torch.distributions as dist
from torch.autograd import Variable, Function, grad
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import numpy as np
import argparse
import torch.nn as nn
import scipy.sparse as sp
def lexpand(A, *dimensions):
"""Expand tensor, adding new dimensions on left."""
return A.expand(tuple(dimensions) + A.shape)
def rexpand(A, *dimensions):
"""Expand tensor, adding new dimensions on right."""
return A.view(A.shape + (1,)*len(dimensions)).expand(A.shape + tuple(dimensions))
def assert_no_nan(name, g):
if torch.isnan(g).any(): raise Exception('nans in {}'.format(name))
def assert_no_grad_nan(name, x):
if x.requires_grad: x.register_hook(lambda g: assert_no_nan(name, g))
# Classes
class Constants(object):
eta = 1e-5
log2 = math.log(2)
logpi = math.log(math.pi)
log2pi = math.log(2 * math.pi)
logceilc = 88 # largest cuda v s.t. exp(v) < inf
logfloorc = -104 # smallest cuda v s.t. exp(v) > 0
invsqrt2pi = 1. / math.sqrt(2 * math.pi)
sqrthalfpi = math.sqrt(math.pi/2)
def logsinh(x):
    # torch.log(sinh(x)), numerically stable for large x; only valid for x > 0
    return x + torch.log(1 - torch.exp(-2 * x)) - Constants.log2
def logcosh(x):
    # torch.log(cosh(x)), numerically stable for large positive x
    return x + torch.log(1 + torch.exp(-2 * x)) - Constants.log2
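# Quick sanity check (hedged): logsinh(torch.tensor(10.)) ~= 10 - log 2 ~= 9.3069,
# matching torch.log(torch.sinh(torch.tensor(10.))); at x = 100, torch.sinh
# overflows to inf in float32 while logsinh stays finite.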
class Arccosh(Function):
# https://github.com/facebookresearch/poincare-embeddings/blob/master/model.py
@staticmethod
def forward(ctx, x):
ctx.z = torch.sqrt(x * x - 1)
return torch.log(x + ctx.z)
@staticmethod
def backward(ctx, g):
z = torch.clamp(ctx.z, min=Constants.eta)
z = g / z
return z
class Arcsinh(Function):
@staticmethod
def forward(ctx, x):
ctx.z = torch.sqrt(x * x + 1)
return torch.log(x + ctx.z)
@staticmethod
def backward(ctx, g):
z = torch.clamp(ctx.z, min=Constants.eta)
z = g / z
return z
# https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
class Logger(object):
def __init__(self, filename):
self.terminal = sys.stdout
self.log = open(filename, "a")
def write(self, message):
self.terminal.write(message)
self.log.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
class Timer:
def __init__(self, name):
self.name = name
def __enter__(self):
self.begin = time.time()
return self
def __exit__(self, *args):
self.end = time.time()
self.elapsed = self.end - self.begin
self.elapsedH = time.gmtime(self.elapsed)
print('====> [{}] Time: {:7.3f}s or {}'
.format(self.name,
self.elapsed,
time.strftime("%H:%M:%S", self.elapsedH)))
# Functions
def save_vars(vs, filepath):
"""
Saves variables to the given filepath in a safe manner.
"""
if os.path.exists(filepath):
shutil.copyfile(filepath, '{}.old'.format(filepath))
torch.save(vs, filepath)
def save_model(model, filepath):
"""
To load a saved model, simply use
`model.load_state_dict(torch.load('path-to-saved-model'))`.
"""
save_vars(model.state_dict(), filepath)
def log_mean_exp(value, dim=0, keepdim=False):
return log_sum_exp(value, dim, keepdim) - math.log(value.size(dim))
def log_sum_exp(value, dim=0, keepdim=False):
m, _ = torch.max(value, dim=dim, keepdim=True)
value0 = value - m
if keepdim is False:
m = m.squeeze(dim)
return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
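# Note: log_sum_exp matches torch.logsumexp(value, dim, keepdim); the signed
# variant below is kept because torch has no built-in equivalent.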
def log_sum_exp_signs(value, signs, dim=0, keepdim=False):
m, _ = torch.max(value, dim=dim, keepdim=True)
value0 = value - m
if keepdim is False:
m = m.squeeze(dim)
return m + torch.log(torch.sum(signs * torch.exp(value0), dim=dim, keepdim=keepdim))
def get_mean_param(params):
    """Return the parameter used to show reconstructions or generations.
    For example, the mean for Normal, or probs for Bernoulli.
    For RelaxedBernoulli, skip the first parameter, as that is the (scalar) temperature.
    """
if params[0].dim() == 0:
return params[1]
# elif len(params) == 3:
# return params[1]
else:
return params[0]
def probe_infnan(v, name, extras={}):
    # despite the name, only NaNs are detected here
    nps = torch.isnan(v)
s = nps.sum().item()
if s > 0:
print('>>> {} >>>'.format(name))
print(name, s)
print(v[nps])
for k, val in extras.items():
print(k, val, val.sum().item())
quit()
def has_analytic_kl(type_p, type_q):
return (type_p, type_q) in torch.distributions.kl._KL_REGISTRY
def split_data(labels, test_prop,val_prop):
nb_nodes = labels.shape[0]
all_idx = np.arange(nb_nodes)
pos_idx = labels.nonzero()[0]
neg_idx = (1. - labels).nonzero()[0]
np.random.shuffle(pos_idx)
np.random.shuffle(neg_idx)
pos_idx = pos_idx.tolist()
neg_idx = neg_idx.tolist()
nb_pos_neg = min(len(pos_idx), len(neg_idx))
nb_val = round(val_prop * nb_pos_neg)
nb_test = round(test_prop * nb_pos_neg)
    idx_val_pos, idx_test_pos, idx_train_pos = (
        pos_idx[:nb_val], pos_idx[nb_val:nb_val + nb_test], pos_idx[nb_val + nb_test:])
    idx_val_neg, idx_test_neg, idx_train_neg = (
        neg_idx[:nb_val], neg_idx[nb_val:nb_val + nb_test], neg_idx[nb_val + nb_test:])
    return idx_test_pos + idx_test_neg, idx_train_pos + idx_train_neg, idx_val_pos + idx_val_neg
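# Hedged worked example: with 40 positive and 60 negative labels, nb_pos_neg = 40;
# val_prop=0.1 and test_prop=0.3 give nb_val=4 and nb_test=12 per class, so the
# val/test splits are class-balanced and all surplus majority-class nodes fall
# into the training indices.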
def process_data(args, adj,features,labels):
data = process_data_nc(args,adj,features,labels)
data['adj_train'], data['features'] = process(
data['adj_train'], data['features'],args.normalize_adj,args.normalize_feats
)
return data
def process_data_nc(args,adj,features,labels):
idx_test, idx_train , idx_val= split_data(labels, args.test_prop,args.val_prop)
labels = torch.LongTensor(labels)
data = {'adj_train': sp.csr_matrix(adj), 'features': features, 'labels': labels, 'idx_train': idx_train, 'idx_test': idx_test , 'idx_val':idx_val}
return data
def process(adj, features, normalize_adj, normalize_feats):
if sp.isspmatrix(features):
features = np.array(features.todense())
if normalize_feats:
features = normalize(features)
features = torch.Tensor(features)
if normalize_adj:
adj = normalize(adj)
adj = sparse_mx_to_torch_sparse_tensor(adj)
return adj, features
def normalize(mx):
"""Row-normalize sparse matrix."""
rowsum = np.array(mx.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
mx = r_mat_inv.dot(mx)
return mx
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
"""Convert a scipy sparse matrix to a torch sparse tensor."""
sparse_mx = sparse_mx.tocoo()
indices = torch.from_numpy(
np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
)
values = torch.Tensor(sparse_mx.data)
shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)  # torch.sparse.FloatTensor is deprecated
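# Hedged example of the conversion above:
#   t = sparse_mx_to_torch_sparse_tensor(sp.eye(2).tocsr())
#   t.to_dense()  # tensor([[1., 0.], [0., 1.]])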
def create_args(*args):
    # Builds an argparse Namespace from positional defaults; argparse is used only
    # so command-line overrides remain possible. Note that `type=bool` flags do not
    # parse strings the way one might expect (bool('False') is True).
    parser = argparse.ArgumentParser()
parser.add_argument('--dim', type=int, default=args[0])
parser.add_argument('--hidden_dim', type=int, default=args[1])
parser.add_argument('--num_layers', type=int, default=args[2])
parser.add_argument('--c', type=int, default=args[3])
parser.add_argument('--act', type=str, default=args[4])
parser.add_argument('--lr', type=float, default=args[5])
parser.add_argument('--cuda', type=int, default=args[6])
parser.add_argument('--epochs', type=int, default=args[7])
parser.add_argument('--seed', type=int, default=args[8])
parser.add_argument('--eval_freq', type=int, default=args[9])
parser.add_argument('--val_prop', type=float, default=args[10])
parser.add_argument('--test_prop', type=float, default=args[11])
parser.add_argument('--dropout', type=float, default=args[12])
parser.add_argument('--beta1', type=float, default=args[13])
parser.add_argument('--beta2', type=float, default=args[14])
parser.add_argument('--K', type=int, default=args[15])
parser.add_argument('--beta', type=float, default=args[16])
parser.add_argument('--analytical_kl', type=bool, default=args[17])
parser.add_argument('--posterior', type=str, default=args[18])
parser.add_argument('--prior', type=str, default=args[19])
parser.add_argument('--prior_iso', type=bool, default=args[20])
parser.add_argument('--prior_std', type=float, default=args[21])
parser.add_argument('--learn_prior_std', type=bool, default=args[22])
parser.add_argument('--enc', type=str, default=args[23])
parser.add_argument('--dec', type=str, default=args[24])
parser.add_argument('--bias', type=bool, default=args[25])
parser.add_argument('--alpha', type=float, default=args[26])
parser.add_argument('--classifier', type=str, default=args[27])
parser.add_argument('--clusterer', type=str, default=args[28])
parser.add_argument('--log_freq', type=int, default=args[29])
parser.add_argument('--normalize_adj', type=bool, default=args[30])
parser.add_argument('--normalize_feats', type=bool, default=args[31])
parser.add_argument('--anomaly_detector', type=str, default=args[32])
flags, unknown = parser.parse_known_args()
return flags
def get_activation(args):
if args.act == 'leaky_relu':
return nn.LeakyReLU(args.alpha)
elif args.act == 'rrelu':
return nn.RReLU()
elif args.act == 'relu':
return nn.ReLU()
elif args.act == 'elu':
return nn.ELU()
elif args.act == 'prelu':
return nn.PReLU()
    elif args.act == 'selu':
        return nn.SELU()
    else:
        raise NotImplementedError
from Ghypeddings.classifiers import *
def get_classifier(args,X,y):
if(args.classifier):
if(args.classifier == 'svm'):
return SVM(X,y)
elif(args.classifier == 'mlp'):
return mlp(X,y,1,10,seed=args.seed)
elif(args.classifier == 'decision tree'):
return decision_tree(X,y)
elif(args.classifier == 'random forest'):
return random_forest(X,y,args.seed)
elif(args.classifier == 'adaboost'):
return adaboost(X,y,args.seed)
elif(args.classifier == 'knn'):
return KNN(X,y)
elif(args.classifier == 'naive bayes'):
return naive_bayes(X,y)
else:
raise NotImplementedError
from Ghypeddings.clusterers import *
def get_clustering_algorithm(clusterer,X,y):
if(clusterer == 'agglomerative_clustering'):
return agglomerative_clustering(X,y)
elif(clusterer == 'dbscan'):
return dbscan(X,y)
elif(clusterer == 'fuzzy_c_mean'):
return fuzzy_c_mean(X,y)
elif(clusterer == 'gaussian_mixture'):
return gaussian_mixture(X,y)
elif(clusterer == 'kmeans'):
return kmeans(X,y)
elif(clusterer == 'mean_shift'):
return mean_shift(X,y)
else:
raise NotImplementedError
from Ghypeddings.anomaly_detection import *
def get_anomaly_detection_algorithm(algorithm,X,y):
if(algorithm == 'isolation_forest'):
return isolation_forest(X,y)
elif(algorithm == 'one_class_svm'):
return one_class_svm(X,y)
elif(algorithm == 'dbscan'):
return dbscan(X,y)
elif(algorithm == 'kmeans'):
return kmeans(X,y,n_clusters=2)
elif(algorithm == 'local_outlier_factor'):
return local_outlier_factor(X,y)
else:
raise NotImplementedError
"""Euclidean layers."""
from __future__ import print_function
from __future__ import division
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
def get_dim_act(args):
    """
    Helper function to get the dimension and activation at every layer.
    :param args: namespace carrying `act`, `num_layers`, `feat_dim`, `dim` and `task`.
    :return: a `(dims, acts)` pair with one entry per layer transition.
    """
if not args.act:
act = lambda x: x
else:
act = getattr(F, args.act)
acts = [act] * (args.num_layers - 1)
dims = [args.feat_dim] + ([args.dim] * (args.num_layers - 1))
if args.task in ['lp', 'rec']:
dims += [args.dim]
acts += [act]
return dims, acts
class Linear(Module):
"""
Simple Linear layer with dropout.
"""
def __init__(self, in_features, out_features, dropout, act, use_bias):
super(Linear, self).__init__()
self.dropout = dropout
self.linear = nn.Linear(in_features, out_features, use_bias)
self.act = act
def forward(self, x):
hidden = self.linear.forward(x)
hidden = F.dropout(hidden, self.dropout, training=self.training)
out = self.act(hidden)
return out
from Ghypeddings.Poincare.manifolds.base import ManifoldParameter
from Ghypeddings.Poincare.manifolds.poincare import PoincareBall
from Ghypeddings.Poincare.manifolds.euclidean import Euclidean
"""Base manifold."""
from torch.nn import Parameter
class Manifold(object):
"""
Abstract class to define operations on a manifold.
"""
def __init__(self):
super().__init__()
self.eps = 10e-8
def sqdist(self, p1, p2, c):
"""Squared distance between pairs of points."""
raise NotImplementedError
def egrad2rgrad(self, p, dp, c):
"""Converts Euclidean Gradient to Riemannian Gradients."""
raise NotImplementedError
def proj(self, p, c):
"""Projects point p on the manifold."""
raise NotImplementedError
def proj_tan(self, u, p, c):
"""Projects u on the tangent space of p."""
raise NotImplementedError
def proj_tan0(self, u, c):
"""Projects u on the tangent space of the origin."""
raise NotImplementedError
def expmap(self, u, p, c):
"""Exponential map of u at point p."""
raise NotImplementedError
def logmap(self, p1, p2, c):
"""Logarithmic map of point p1 at point p2."""
raise NotImplementedError
def expmap0(self, u, c):
"""Exponential map of u at the origin."""
raise NotImplementedError
def logmap0(self, p, c):
"""Logarithmic map of point p at the origin."""
raise NotImplementedError
def mobius_add(self, x, y, c, dim=-1):
"""Adds points x and y."""
raise NotImplementedError
    def mobius_matvec(self, m, x, c):
        """Performs hyperbolic matrix-vector multiplication."""
raise NotImplementedError
    def init_weights(self, w, c, irange=1e-5):
        """Initializes random weights on the manifold."""
raise NotImplementedError
def inner(self, p, c, u, v=None, keepdim=False):
"""Inner product for tangent vectors at point x."""
raise NotImplementedError
def ptransp(self, x, y, u, c):
"""Parallel transport of u from x to y."""
raise NotImplementedError
    def ptransp0(self, x, u, c):
        """Parallel transport of u from the origin to x."""
raise NotImplementedError
class ManifoldParameter(Parameter):
"""
Subclass of torch.nn.Parameter for Riemannian optimization.
"""
def __new__(cls, data, requires_grad, manifold, c):
return Parameter.__new__(cls, data, requires_grad)
def __init__(self, data, requires_grad, manifold, c):
self.c = c
self.manifold = manifold
def __repr__(self):
return '{} Parameter containing:\n'.format(self.manifold.name) + super(Parameter, self).__repr__()
"""Euclidean manifold."""
from Ghypeddings.Poincare.manifolds.base import Manifold
class Euclidean(Manifold):
"""
Euclidean Manifold class.
"""
def __init__(self):
super(Euclidean, self).__init__()
self.name = 'Euclidean'
def normalize(self, p):
dim = p.size(-1)
p.view(-1, dim).renorm_(2, 0, 1.)
return p
def sqdist(self, p1, p2, c):
return (p1 - p2).pow(2).sum(dim=-1)
def egrad2rgrad(self, p, dp, c):
return dp
def proj(self, p, c):
return p
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
return p + u
def logmap(self, p1, p2, c):
return p2 - p1
def expmap0(self, u, c):
return u
def logmap0(self, p, c):
return p
def mobius_add(self, x, y, c, dim=-1):
return x + y
def mobius_matvec(self, m, x, c):
mx = x @ m.transpose(-1, -2)
return mx
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def inner(self, p, c, u, v=None, keepdim=False):
if v is None:
v = u
return (u * v).sum(dim=-1, keepdim=keepdim)
def ptransp(self, x, y, v, c):
return v
def ptransp0(self, x, v, c):
return x + v
"""Poincare ball manifold."""
import torch
from Ghypeddings.Poincare.manifolds.base import Manifold
from Ghypeddings.Poincare.utils.math_utils import artanh, tanh
class PoincareBall(Manifold):
    """
    PoincareBall Manifold class.
    We use the following convention: x0^2 + x1^2 + ... + xd^2 < 1 / c
    Note that 1/sqrt(c) is the Poincare ball radius.
    """
    def __init__(self):
super(PoincareBall, self).__init__()
self.name = 'PoincareBall'
self.min_norm = 1e-15
self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}
def sqdist(self, p1, p2, c):
sqrt_c = c ** 0.5
dist_c = artanh(
sqrt_c * self.mobius_add(-p1, p2, c, dim=-1).norm(dim=-1, p=2, keepdim=False)
)
dist = dist_c * 2 / sqrt_c
return dist ** 2
def _lambda_x(self, x, c):
x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
return 2 / (1. - c * x_sqnorm).clamp_min(self.min_norm)
def egrad2rgrad(self, p, dp, c):
lambda_p = self._lambda_x(p, c)
dp /= lambda_p.pow(2)
return dp
def proj(self, x, c):
norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
cond = norm > maxnorm
projected = x / norm * maxnorm
return torch.where(cond, projected, x)
def proj_tan(self, u, p, c):
return u
def proj_tan0(self, u, c):
return u
def expmap(self, u, p, c):
sqrt_c = c ** 0.5
u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
second_term = (
tanh(sqrt_c / 2 * self._lambda_x(p, c) * u_norm)
* u
/ (sqrt_c * u_norm)
)
gamma_1 = self.mobius_add(p, second_term, c)
return gamma_1
def logmap(self, p1, p2, c):
sub = self.mobius_add(-p1, p2, c)
sub_norm = sub.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
lam = self._lambda_x(p1, c)
sqrt_c = c ** 0.5
return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm
def expmap0(self, u, c):
sqrt_c = c ** 0.5
u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm)
gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
return gamma_1
def logmap0(self, p, c):
sqrt_c = c ** 0.5
p_norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
scale = 1. / sqrt_c * artanh(sqrt_c * p_norm) / p_norm
return scale * p
def mobius_add(self, x, y, c, dim=-1):
x2 = x.pow(2).sum(dim=dim, keepdim=True)
y2 = y.pow(2).sum(dim=dim, keepdim=True)
xy = (x * y).sum(dim=dim, keepdim=True)
num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
return num / denom.clamp_min(self.min_norm)
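    # Reference: the formula above is Ungar's Mobius addition,
    #   x (+)_c y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y)
    #               / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2),
    # which reduces to ordinary vector addition as c -> 0.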
def mobius_matvec(self, m, x, c):
sqrt_c = c ** 0.5
x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
mx = x @ m.transpose(-1, -2)
mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
cond = (mx == 0).prod(-1, keepdim=True, dtype=torch.uint8)
res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
res = torch.where(cond, res_0, res_c)
return res
def init_weights(self, w, c, irange=1e-5):
w.data.uniform_(-irange, irange)
return w
def _gyration(self, u, v, w, c, dim: int = -1):
u2 = u.pow(2).sum(dim=dim, keepdim=True)
v2 = v.pow(2).sum(dim=dim, keepdim=True)
uv = (u * v).sum(dim=dim, keepdim=True)
uw = (u * w).sum(dim=dim, keepdim=True)
vw = (v * w).sum(dim=dim, keepdim=True)
c2 = c ** 2
a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw
b = -c2 * vw * u2 - c * uw
d = 1 + 2 * c * uv + c2 * u2 * v2
return w + 2 * (a * u + b * v) / d.clamp_min(self.min_norm)
def inner(self, x, c, u, v=None, keepdim=False):
if v is None:
v = u
lambda_x = self._lambda_x(x, c)
return lambda_x ** 2 * (u * v).sum(dim=-1, keepdim=keepdim)
def ptransp(self, x, y, u, c):
lambda_x = self._lambda_x(x, c)
lambda_y = self._lambda_x(y, c)
return self._gyration(y, -x, u, c) * lambda_x / lambda_y
    def ptransp_(self, x, y, u, c):
        # identical to ptransp above; kept for backward API compatibility
lambda_x = self._lambda_x(x, c)
lambda_y = self._lambda_x(y, c)
return self._gyration(y, -x, u, c) * lambda_x / lambda_y
def ptransp0(self, x, u, c):
lambda_x = self._lambda_x(x, c)
return 2 * u / lambda_x.clamp_min(self.min_norm)
def to_hyperboloid(self, x, c):
K = 1./ c
sqrtK = K ** 0.5
sqnorm = torch.norm(x, p=2, dim=1, keepdim=True) ** 2
return sqrtK * torch.cat([K + sqnorm, 2 * sqrtK * x], dim=1) / (K - sqnorm)
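# --- Hedged round-trip check (added for illustration; not in the original file) ---
# logmap0 should invert expmap0 at the origin, up to floating-point error.
if __name__ == "__main__":
    ball = PoincareBall()
    c = torch.tensor(1.0)
    u = 0.1 * torch.randn(3, 2)
    p = ball.expmap0(u, c)
    assert torch.allclose(ball.logmap0(p, c), u, atol=1e-4)
    print("expmap0/logmap0 round trip OK")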
"""Base model class."""
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import Ghypeddings.Poincare.manifolds as manifolds
import Ghypeddings.Poincare.models.encoders as encoders
from Ghypeddings.Poincare.models.decoders import model2decoder
from Ghypeddings.Poincare.utils.eval_utils import acc_f1
class BaseModel(nn.Module):
"""
Base model for graph embedding tasks.
"""
def __init__(self, args):
super(BaseModel, self).__init__()
self.manifold_name = 'PoincareBall'
self.c = torch.tensor([1.0])
if not args.cuda == -1:
self.c = self.c.to(args.device)
self.manifold = getattr(manifolds, self.manifold_name)()
self.nnodes = args.n_nodes
self.encoder = getattr(encoders, 'Shallow')(self.c, args)
def encode(self, x):
h = self.encoder.encode(x)
return h
def compute_metrics(self, embeddings, data, split):
raise NotImplementedError
def init_metric_dict(self):
raise NotImplementedError
def has_improved(self, m1, m2):
raise NotImplementedError
class NCModel(BaseModel):
"""
Base model for node classification task.
"""
def __init__(self, args):
super(NCModel, self).__init__(args)
self.decoder = model2decoder(1.0, args)
if args.n_classes > 2:
self.f1_average = 'micro'
else:
self.f1_average = 'binary'
self.weights = torch.Tensor([1.] * args.n_classes)
if not args.cuda == -1:
self.weights = self.weights.to(args.device)
def decode(self, h, idx):
output = self.decoder.decode(h)
return F.log_softmax(output[idx], dim=1)
def compute_metrics(self, embeddings, data, split):
idx = data[f'idx_{split}']
output = self.decode(embeddings, idx)
loss = F.nll_loss(output, data['labels'][idx], self.weights)
acc, f1,recall,precision,roc_auc = acc_f1(output, data['labels'][idx], average=self.f1_average)
metrics = {'loss': loss, 'acc': acc, 'f1': f1,'recall':recall,'precision':precision,'roc_auc':roc_auc}
return metrics
def init_metric_dict(self):
return {'acc': -1, 'f1': -1}
def has_improved(self, m1, m2):
return m1["f1"] < m2["f1"]
"""Graph decoders."""
import Ghypeddings.Poincare.manifolds as manifolds
import torch.nn as nn
import torch.nn.functional as F
from Ghypeddings.Poincare.layers.layers import Linear
import torch
class Decoder(nn.Module):
"""
Decoder abstract class for node classification tasks.
"""
def __init__(self, c):
super(Decoder, self).__init__()
self.c = c
def decode(self, x):
probs = self.cls.forward(x)
return probs
class LinearDecoder(Decoder):
"""
MLP Decoder for Hyperbolic/Euclidean node classification models.
"""
def __init__(self, c, args):
super(LinearDecoder, self).__init__(c)
self.manifold = getattr(manifolds, 'PoincareBall')()
self.input_dim = args.dim + args.feat_dim
self.output_dim = args.n_classes
self.bias = True
self.cls = Linear(self.input_dim, self.output_dim, args.dropout, lambda x: x, self.bias)
def decode(self, x):
h = self.manifold.proj_tan0(self.manifold.logmap0(x, c=self.c), c=self.c)
return super(LinearDecoder, self).decode(h)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}, c={}'.format(
self.input_dim, self.output_dim, self.bias, self.c
)
model2decoder = LinearDecoder