From a6f99a87c7d7bcfff4b76ff9b54382f7b13d5a19 Mon Sep 17 00:00:00 2001
From: Abderaouf Gacem <gcmabderaouf@gmail.com>
Date: Mon, 13 Feb 2023 11:32:57 +0000
Subject: [PATCH] init

---
 utils.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 utils.py

diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..ac99db0
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,68 @@
+import torch_geometric.transforms as T
+
+from torch_geometric.datasets import Planetoid, AttributedGraphDataset, Reddit, Flickr
+
+from torch_geometric.loader import GraphSAINTRandomWalkSampler, ClusterData, ClusterLoader, NeighborSampler
+from ForestFireSampler import ForestFireSampler 
+
+
+
+
+def get_dataset(dataset, path='./data/', **kwargs):  # Resolve a dataset name to a loaded torch_geometric dataset, optionally re-splitting train/val/test
+    if dataset in ["Cora", "CiteSeer", "PubMed"] :  # Planetoid citation graphs
+        if 'num_val' in kwargs and 'num_test' in kwargs:  # caller asked for a custom random split
+            num_val, num_test = kwargs['num_val'], kwargs['num_test']
+            if num_val > 1 :  # values > 1 are treated as absolute node counts
+                return Planetoid(path, dataset, transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test))
+            else :  # values <= 1 are treated as fractions of the total node count
+                dataset = Planetoid(path, dataset)
+                num_nodes = dataset.data.x.shape[0]
+                dataset.data = T.RandomNodeSplit(split='train_rest', num_val=int(num_val * num_nodes), num_test=int(num_test * num_nodes))(dataset.data)  # eager split applied once; NOTE(review): this branch forces split='train_rest' while the count branch above keeps RandomNodeSplit's default — confirm both are intended
+                return dataset
+
+        return Planetoid(path, dataset)  # no split kwargs: keep the dataset's public split masks
+    elif dataset in ["BlogCatalog", "Facebook", "Twitter", "Wiki"] :  # attributed social graphs; these ship without split masks
+        if 'num_val' in kwargs and 'num_test' in kwargs:
+            num_val, num_test = kwargs['num_val'], kwargs['num_test'] 
+        else :
+            num_val = num_test = 0.2  # default: 20% val / 20% test
+        if num_val > 1 :  # absolute counts -> lazy transform applied on access
+            dataset = AttributedGraphDataset(path, dataset, transform=T.RandomNodeSplit(num_val=num_val, num_test=num_test))
+        else :  # fractions -> eager split applied once, like the Planetoid branch
+            dataset = AttributedGraphDataset(path, dataset)
+            num_nodes = dataset.data.x.shape[0]
+            dataset.data = T.RandomNodeSplit(split='train_rest', num_val=int(num_val * num_nodes), num_test=int(num_test * num_nodes))(dataset.data)
+        
+        dataset.num_nodes = dataset.data.x.shape[0]  # cache node count on the dataset object for callers
+        
+        return dataset
+    elif dataset == "Reddit":
+        return Reddit(path + dataset)  # stored under its own subdirectory (string concat, not os.path.join)
+    elif dataset == "Flickr" :
+        return Flickr(path + dataset)
+    else :
+        raise NotImplementedError(dataset + " not supported")
+
+
+def get_loader(loader, data, save_dir, batch_size, num_workers, **kwargs):  # Dispatch a loader name to a PyG sampler/loader; missing required kwargs raise KeyError
+    if loader == "cluster_gcn" :  # Cluster-GCN: partition the graph, then serve clusters as mini-batches
+        cluster_data = ClusterData(data, num_parts=kwargs["num_parts"], recursive=False,
+                            save_dir=save_dir)  # partitioning is cached on disk under save_dir
+        loader = ClusterLoader(cluster_data, batch_size=batch_size, shuffle=True,
+                                num_workers=num_workers)
+    elif loader == "graph_saint_rw" :  # GraphSAINT with random-walk subgraph sampling
+        loader = GraphSAINTRandomWalkSampler(data, batch_size=batch_size, walk_length=kwargs["walk_length"],
+                                     num_steps=kwargs["num_steps"], sample_coverage=kwargs["sample_coverage"],
+                                     save_dir=save_dir,  # normalization statistics cached under save_dir
+                                     num_workers=num_workers)
+    elif loader == "forest_fire" :  # project-local sampler (see ForestFireSampler import); mirrors the GraphSAINT interface
+        loader = ForestFireSampler(data, batch_size=batch_size, p=kwargs["p"], connectivity=kwargs["connectivity"],
+                                     num_steps=kwargs["num_steps"], sample_coverage=kwargs["sample_coverage"],
+                                     save_dir=save_dir,
+                                     num_workers=num_workers)
+    elif loader == "test_loader" :  # full 1-hop neighborhoods, deterministic order — evaluation/inference use
+        loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=batch_size,
+                                  shuffle=False, num_workers=num_workers)
+
+    return loader  # NOTE(review): an unrecognized name falls through all branches and the input string is returned unchanged — consider raising NotImplementedError like get_dataset
+
-- 
GitLab