Skip to content
Snippets Groups Projects
Commit a6f99a87 authored by Abderaouf Gacem's avatar Abderaouf Gacem
Browse files

init

parent 702fd61c
No related branches found
No related tags found
No related merge requests found
utils.py 0 → 100644
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, AttributedGraphDataset, Reddit, Flickr
from torch_geometric.loader import GraphSAINTRandomWalkSampler, ClusterData, ClusterLoader, NeighborSampler
from ForestFireSampler import ForestFireSampler
def get_dataset(dataset, path='./data/', **kwargs):
if dataset in ["Cora", "CiteSeer", "PubMed"] :
if 'num_val' in kwargs and 'num_test' in kwargs:
num_val, num_test = kwargs['num_val'], kwargs['num_test']
if num_val > 1 :
return Planetoid(path, dataset, transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test))
else :
dataset = Planetoid(path, dataset)
num_nodes = dataset.data.x.shape[0]
dataset.data = T.RandomNodeSplit(split='train_rest', num_val=int(num_val * num_nodes), num_test=int(num_test * num_nodes))(dataset.data)
return dataset
return Planetoid(path, dataset)
elif dataset in ["BlogCatalog", "Facebook", "Twitter", "Wiki"] :
if 'num_val' in kwargs and 'num_test' in kwargs:
num_val, num_test = kwargs['num_val'], kwargs['num_test']
else :
num_val = num_test = 0.2
if num_val > 1 :
dataset = AttributedGraphDataset(path, dataset, transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test))
else :
dataset = AttributedGraphDataset(path, dataset)
num_nodes = dataset.data.x.shape[0]
dataset.data = T.RandomNodeSplit(split='train_rest', num_val=int(num_val * num_nodes), num_test=int(num_test * num_nodes))(dataset.data)
dataset.num_nodes = dataset.data.x.shape[0]
return dataset
elif dataset == "Reddit":
return Reddit(path + dataset)
elif dataset == "Flickr" :
return Flickr(path + dataset)
else :
raise NotImplementedError(dataset + " not supported")
def get_loader(loader, data, save_dir, batch_size, num_workers, **kwargs):
if loader == "cluster_gcn" :
cluster_data = ClusterData(data, num_parts=kwargs["num_parts"], recursive=False,
save_dir=save_dir)
loader = ClusterLoader(cluster_data, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
elif loader == "graph_saint_rw" :
loader = GraphSAINTRandomWalkSampler(data, batch_size=batch_size, walk_length=kwargs["walk_length"],
num_steps=kwargs["num_steps"], sample_coverage=kwargs["sample_coverage"],
save_dir=save_dir,
num_workers=num_workers)
elif loader == "forest_fire" :
loader = ForestFireSampler(data, batch_size=batch_size, p=kwargs["p"], connectivity=kwargs["connectivity"],
num_steps=kwargs["num_steps"], sample_coverage=kwargs["sample_coverage"],
save_dir=save_dir,
num_workers=num_workers)
elif loader == "test_loader" :
loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=batch_size,
shuffle=False, num_workers=num_workers)
return loader
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment