"""Dataset loading and mini-batch loader construction helpers for PyG experiments."""

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, AttributedGraphDataset, Reddit, Flickr
from torch_geometric.loader import (
    GraphSAINTRandomWalkSampler,
    ClusterData,
    ClusterLoader,
    NeighborSampler,
)

from ForestFireSampler import ForestFireSampler


def get_dataset(dataset, path='./data/', **kwargs):
    """Load a graph dataset by name, optionally re-splitting train/val/test.

    Parameters
    ----------
    dataset : str
        One of "Cora", "CiteSeer", "PubMed" (Planetoid), "BlogCatalog",
        "Facebook", "Twitter", "Wiki" (AttributedGraphDataset), "Reddit",
        or "Flickr".
    path : str
        Root directory where datasets are downloaded/cached.
    **kwargs
        Optional ``num_val`` and ``num_test`` (both must be present to take
        effect). Values > 1 are treated as absolute node counts; values
        <= 1 are treated as fractions of the total number of nodes.

    Returns
    -------
    The constructed torch_geometric dataset, with train/val/test node masks
    attached where a split was requested.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not one of the supported names.
    """
    if dataset in ("Cora", "CiteSeer", "PubMed"):
        if 'num_val' in kwargs and 'num_test' in kwargs:
            num_val, num_test = kwargs['num_val'], kwargs['num_test']
            if num_val > 1:
                # Absolute counts: let the transform split lazily on access.
                return Planetoid(
                    path, dataset,
                    transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test),
                )
            # Fractions: convert to node counts and split eagerly, once.
            planetoid = Planetoid(path, dataset)
            num_nodes = planetoid.data.x.shape[0]
            planetoid.data = T.RandomNodeSplit(
                split='train_rest',
                num_val=int(num_val * num_nodes),
                num_test=int(num_test * num_nodes),
            )(planetoid.data)
            return planetoid
        # No split requested: keep the dataset's default public split.
        return Planetoid(path, dataset)

    if dataset in ("BlogCatalog", "Facebook", "Twitter", "Wiki"):
        if 'num_val' in kwargs and 'num_test' in kwargs:
            num_val, num_test = kwargs['num_val'], kwargs['num_test']
        else:
            # These datasets ship without masks; default to 20% val / 20% test.
            num_val = num_test = 0.2
        if num_val > 1:
            # Absolute counts: lazy split via transform.
            attributed = AttributedGraphDataset(
                path, dataset,
                transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test),
            )
        else:
            # Fractions: eager split with counts derived from the node total.
            attributed = AttributedGraphDataset(path, dataset)
            num_nodes = attributed.data.x.shape[0]
            attributed.data = T.RandomNodeSplit(
                split='train_rest',
                num_val=int(num_val * num_nodes),
                num_test=int(num_test * num_nodes),
            )(attributed.data)

        # Callers read num_nodes off the dataset object directly.
        attributed.num_nodes = attributed.data.x.shape[0]
        return attributed

    if dataset == "Reddit":
        return Reddit(path + dataset)
    if dataset == "Flickr":
        return Flickr(path + dataset)

    raise NotImplementedError(dataset + " not supported")


def get_loader(loader, data, save_dir, batch_size, num_workers, **kwargs):
    """Build a mini-batch sampler/loader over ``data``.

    Parameters
    ----------
    loader : str
        One of "cluster_gcn", "graph_saint_rw", "forest_fire", "test_loader".
    data : torch_geometric.data.Data
        The (full) graph to sample from.
    save_dir : str
        Directory for loader-specific preprocessing caches.
    batch_size, num_workers : int
        Standard DataLoader parameters.
    **kwargs
        Loader-specific options: ``num_parts`` (cluster_gcn);
        ``walk_length``, ``num_steps``, ``sample_coverage`` (graph_saint_rw);
        ``p``, ``connectivity``, ``num_steps``, ``sample_coverage``
        (forest_fire).

    Raises
    ------
    NotImplementedError
        If ``loader`` is not a recognized name. (Previously the unrecognized
        name string was silently returned, deferring the failure to the
        caller's first iteration attempt.)
    """
    if loader == "cluster_gcn":
        cluster_data = ClusterData(
            data, num_parts=kwargs["num_parts"], recursive=False,
            save_dir=save_dir,
        )
        return ClusterLoader(
            cluster_data, batch_size=batch_size, shuffle=True,
            num_workers=num_workers,
        )
    if loader == "graph_saint_rw":
        return GraphSAINTRandomWalkSampler(
            data, batch_size=batch_size, walk_length=kwargs["walk_length"],
            num_steps=kwargs["num_steps"],
            sample_coverage=kwargs["sample_coverage"],
            save_dir=save_dir, num_workers=num_workers,
        )
    if loader == "forest_fire":
        return ForestFireSampler(
            data, batch_size=batch_size, p=kwargs["p"],
            connectivity=kwargs["connectivity"],
            num_steps=kwargs["num_steps"],
            sample_coverage=kwargs["sample_coverage"],
            save_dir=save_dir, num_workers=num_workers,
        )
    if loader == "test_loader":
        # Full-neighborhood sampler (sizes=[-1]) for deterministic evaluation.
        return NeighborSampler(
            data.edge_index, sizes=[-1], batch_size=batch_size,
            shuffle=False, num_workers=num_workers,
        )

    raise NotImplementedError(loader + " not supported")