diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf9a481407696d73f30fe1dde279154a05702b3
--- /dev/null
+++ b/trainer/utils/utils.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import numpy as np
+import random
+import torch.optim as optim
+
+
+def create_optimizer(opt, model, lr, weight_decay):
+    """Build an optimizer for `model` from its (possibly prefixed) name."""
+    opt_lower = opt.lower()
+    parameters = model.parameters()
+    opt_args = dict(lr=lr, weight_decay=weight_decay)
+
+    # Optimizer names may be prefixed (e.g. "fused_adam"); only the last part is used.
+    opt_split = opt_lower.split("_")
+    opt_lower = opt_split[-1]
+
+    if opt_lower == "adam":
+        optimizer = optim.Adam(parameters, **opt_args)
+    elif opt_lower == "adamw":
+        optimizer = optim.AdamW(parameters, **opt_args)
+    elif opt_lower == "adadelta":
+        optimizer = optim.Adadelta(parameters, **opt_args)
+    elif opt_lower == "radam":
+        optimizer = optim.RAdam(parameters, **opt_args)
+    elif opt_lower == "sgd":
+        opt_args["momentum"] = 0.9
+        optimizer = optim.SGD(parameters, **opt_args)
+    else:
+        raise ValueError(f"Invalid optimizer: {opt}")
+    return optimizer
+
+
+def random_shuffle(x, y):
+    # x and y must support indexing with a list (e.g. numpy arrays or tensors).
+    idx = list(range(len(x)))
+    random.shuffle(idx)
+    return x[idx], y[idx]
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def create_activation(name):
+    if name == "relu":
+        return nn.ReLU()
+    elif name == "gelu":
+        return nn.GELU()
+    elif name == "prelu":
+        return nn.PReLU()
+    elif name == "elu":
+        return nn.ELU()
+    elif name is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError(f"{name} is not implemented.")
+
+
+def create_norm(name):
+    if name == "layernorm":
+        return nn.LayerNorm
+    elif name == "batchnorm":
+        return nn.BatchNorm1d
+    elif name == "graphnorm":
+        # norm_type must match a branch in NormLayer.__init__.
+        return partial(NormLayer, norm_type="graphnorm")
+    else:
+        return None
+
+
+class NormLayer(nn.Module):
+    def __init__(self, hidden_dim, norm_type):
+        super().__init__()
+        if norm_type == "batchnorm":
+            self.norm = nn.BatchNorm1d(hidden_dim)
+        elif norm_type == "layernorm":
+            self.norm = nn.LayerNorm(hidden_dim)
+        elif norm_type == "graphnorm":
+            # GraphNorm: learnable affine parameters plus a learnable mean
+            # scale, applied per graph in the batch (see forward).
+            self.norm = norm_type
+            self.weight = nn.Parameter(torch.ones(hidden_dim))
+            self.bias = nn.Parameter(torch.zeros(hidden_dim))
+            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
+        else:
+            raise NotImplementedError
+
+    def forward(self, graph, x):
+        tensor = x
+        if self.norm is not None and not isinstance(self.norm, str):
+            return self.norm(tensor)
+        elif self.norm is None:
+            return tensor
+
+        # GraphNorm: standardize node features within each graph of the batch.
+        batch_list = graph.batch_num_nodes  # node counts per graph in the batched graph
+        batch_size = len(batch_list)
+        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
+        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
+        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
+
+        # Per-graph mean, broadcast back to every node.
+        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        mean = mean.scatter_add_(0, batch_index, tensor)
+        mean = (mean.T / batch_list).T
+        mean = mean.repeat_interleave(batch_list, dim=0)
+
+        sub = tensor - mean * self.mean_scale
+
+        # Per-graph standard deviation, broadcast back to every node.
+        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        std = std.scatter_add_(0, batch_index, sub.pow(2))
+        std = ((std.T / batch_list).T + 1e-6).sqrt()
+        std = std.repeat_interleave(batch_list, dim=0)
+        return self.weight * sub / std + self.bias
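
A minimal usage sketch for the helpers added above, assuming the module is importable as trainer.utils.utils; the MLP architecture, layer sizes, and hyperparameter values are illustrative placeholders, not part of the patch.

import torch
import torch.nn as nn

from trainer.utils.utils import (
    create_activation,
    create_norm,
    create_optimizer,
    set_random_seed,
)

set_random_seed(42)

# create_norm("batchnorm") returns the nn.BatchNorm1d class; instantiate it
# with the hidden dimension of the layer it follows.
norm_cls = create_norm("batchnorm")
model = nn.Sequential(
    nn.Linear(16, 32),
    norm_cls(32),
    create_activation("gelu"),
    nn.Linear(32, 4),
)

optimizer = create_optimizer("adamw", model, lr=1e-3, weight_decay=5e-4)

x = torch.randn(8, 16)          # dummy batch of 8 samples, 16 features each
loss = model(x).pow(2).mean()   # placeholder loss for illustration
optimizer.zero_grad()
loss.backward()
optimizer.step()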