import random
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def create_optimizer(opt, model, lr, weight_decay):
    """Build a torch optimizer for `model` from its (lowercase) name."""
    opt_lower = opt.lower()
    parameters = model.parameters()
    opt_args = dict(lr=lr, weight_decay=weight_decay)

    # Only the last token of the name matters, e.g. "my_adamw" -> "adamw".
    opt_split = opt_lower.split("_")
    opt_lower = opt_split[-1]
    if opt_lower == "adam":
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == "adadelta":
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == "radam":
        optimizer = optim.RAdam(parameters, **opt_args)
    elif opt_lower == "sgd":
        opt_args["momentum"] = 0.9
        optimizer = optim.SGD(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt}")
    return optimizer

def random_shuffle(x, y):
    # Shuffle x and y with the same random permutation. Expects NumPy arrays
    # or tensors, since plain Python lists cannot be indexed by an index list.
    idx = list(range(len(x)))
    random.shuffle(idx)
    return x[idx], y[idx]

def set_random_seed(seed):
    """Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def create_activation(name):
    if name == "relu":
        return nn.ReLU()
    elif name == "gelu":
        return nn.GELU()
    elif name == "prelu":
        return nn.PReLU()
    elif name is None:
        return nn.Identity()
    elif name == "elu":
        return nn.ELU()
    else:
        raise NotImplementedError(f"{name} is not implemented.")

def create_norm(name):
    # Returns a norm-layer constructor (not an instance), so callers can do
    # create_norm("layernorm")(hidden_dim).
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # Must match the "graphnorm" branch handled by NormLayer below.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return None

class NormLayer(nn.Module):
    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # The string is kept as a marker; forward() then applies the
            # learnable per-graph normalization instead of a torch module.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError

    def forward(self, graph, x):
        tensor = x
        if self.norm is not None and type(self.norm) != str:
            # batchnorm / layernorm: delegate to the wrapped torch module.
            return self.norm(tensor)
        elif self.norm is None:
            return tensor

        # GraphNorm: normalize node features per graph in the batched graph.
        batch_list = graph.batch_num_nodes
        batch_size = len(batch_list)
        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
        # Per-graph mean, broadcast back to node granularity.
        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)
        sub = tensor - mean * self.mean_scale
        # Per-graph standard deviation of the centered features.
        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias
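

# Minimal usage sketch: exercises the helpers above on a toy MLP with random
# data. The layer sizes, optimizer name, and hyperparameters are arbitrary
# assumptions chosen for illustration, not values used elsewhere in this repo.
if __name__ == "__main__":
    set_random_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), create_activation("gelu"), nn.Linear(16, 2))
    optimizer = create_optimizer("adamw", model, lr=1e-3, weight_decay=5e-4)

    x = np.random.randn(32, 8).astype(np.float32)
    y = np.random.randint(0, 2, size=32)
    x, y = random_shuffle(x, y)  # NumPy arrays, so index-list shuffling works

    logits = model(torch.from_numpy(x))
    loss = nn.CrossEntropyLoss()(logits, torch.from_numpy(y))
    loss.backward()
    optimizer.step()
    print(f"toy loss after one step: {loss.item():.4f}")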