Skip to content
Snippets Groups Projects
Commit 0768b6f7 authored by Abderaouf Gacem's avatar Abderaouf Gacem
Browse files

Upload New File

parent a6f99a87
No related branches found
No related tags found
No related merge requests found
import torch
import torch_geometric
from torch_geometric.loader import GraphSAINTSampler
import random
import numpy as np
import networkx as nx
from collections import deque
from typing import Any, Optional, Union
from tqdm import tqdm
class ForestFireSampler(GraphSAINTSampler):
    """GraphSAINT-style sampler that picks nodes with a forest-fire process.

    A "fire" starts at a seed node and repeatedly burns a random fraction of
    the not-yet-sampled neighbours of a randomly chosen burning node, until
    the requested number of nodes has been collected.  When the fire dies
    out, it is restarted from a new seed.

    Args:
        data: Graph data object.  It is converted to an undirected NetworkX
            graph via ``_to_networkx`` and also forwarded to
            ``GraphSAINTSampler``.
        batch_size: Number of nodes to collect per sample (before the
            validation nodes are added, see ``__sample_nodes__``).
        p: Mode of the triangular distribution on ``[0, 1]`` from which the
            fraction of unvisited neighbours to burn is drawn.
        seed: Seed applied to the global ``random`` and ``numpy.random``
            generators at construction time.
        restart_hop_size: Stored on the instance; not used by any method
            visible in this class.
        connectivity: If ``True``, neighbours that were seen but not burned
            are remembered and used as restart seeds before falling back to
            uniformly random seeds.
        num_steps, sample_coverage, save_dir, log, **kwargs: Forwarded
            unchanged to ``GraphSAINTSampler.__init__``.
    """

    def __init__(
        self,
        data,
        batch_size: int = 100,
        p: float = 0.5,
        seed: int = 100,
        restart_hop_size: int = 10,
        connectivity: bool = False,
        num_steps: int = 1, sample_coverage: int = 0,
        save_dir: Optional[str] = None, log: bool = True, **kwargs
    ):
        # Everything the sampling routine reads is set *before* calling the
        # base constructor.
        # NOTE(review): presumably the base constructor may already draw
        # samples (e.g. to compute normalization) -- confirm against the
        # installed torch_geometric version.
        self.p = p
        self.connectivity = connectivity
        self.seed = seed
        self._set_seed()
        self.restart_hop_size = restart_hop_size
        self.data_nx = _to_networkx(data).to_undirected()
        super().__init__(data, batch_size, num_steps, sample_coverage,
                         save_dir, log, **kwargs)

    def __compute_norm__(self):
        """Estimate node and edge normalization coefficients by sampling.

        Subgraphs are drawn repeatedly until the total number of sampled
        nodes reaches ``self.N * self.sample_coverage``; the number of times
        each node and each edge shows up is counted along the way.

        Returns:
            A ``(node_norm, edge_norm)`` tuple of float tensors with
            ``self.N`` and ``self.E`` entries respectively.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        # Identity collate: every batch is a plain list of sampler items.
        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    # NOTE(review): presumably the values stored in the
                    # sampled adjacency are the original edge ids -- confirm
                    # against GraphSAINTSampler.
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            # One full pass over the loader counts as `num_steps` samples.
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        # For every edge, look up how often its `row` endpoint was sampled.
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        # 0 / 0 (edge and its row node never sampled) yields NaN.
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid a division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def _set_seed(self):
        """Seed the global ``random`` and ``numpy.random`` generators."""
        random.seed(self.seed)
        np.random.seed(self.seed)

    def __sample_nodes__(self, batch_size):
        """Draw one forest-fire sample and return its induced edge index.

        Args:
            batch_size: Accepted for interface compatibility.  As in the
                original implementation, the target size is read from
                ``self.__batch_size__`` instead.

        Returns:
            A ``LongTensor`` of shape ``(2, num_edges)`` holding the edges of
            the subgraph induced by the sampled nodes (plus all nodes
            flagged in ``self.data.val_mask`` when that mask exists).  Nodes
            without any induced edge do not appear in the result.
            NOTE(review): the caller presumably reduces this to unique node
            ids -- confirm against GraphSAINTSampler.
        """
        num_nodes = self.data_nx.number_of_nodes()
        # Never ask for more nodes than the graph has; otherwise the search
        # for an unsampled seed below could never succeed.
        target = min(self.__batch_size__, num_nodes)

        self._sampled_nodes = set()
        visited = deque()  # seen-but-not-burned nodes (connectivity mode)
        node_queue = []    # currently burning nodes

        while len(self._sampled_nodes) < target:
            if not node_queue:
                # The fire died out: restart it.
                if self.connectivity and visited:
                    seed_node = visited.popleft()
                else:
                    # Rejection-sample an unsampled node; one exists because
                    # len(self._sampled_nodes) < target <= num_nodes.
                    while True:
                        seed_node = int(np.random.randint(0, num_nodes))
                        if seed_node not in self._sampled_nodes:
                            break
                node_queue.append(seed_node)

            pick = random.randrange(len(node_queue))
            top_node = node_queue[pick]
            self._sampled_nodes.add(top_node)

            neighbors = set(self.data_nx.neighbors(top_node))
            unvisited_neighbors = neighbors.difference(self._sampled_nodes)

            if not unvisited_neighbors:
                # This node can never spread the fire again (the sampled set
                # only grows), so drop it.  Without this the queue would
                # never empty and the loop could spin forever once every
                # queued node is exhausted.  Order is irrelevant because
                # picks are uniform, hence the O(1) swap-and-pop.
                node_queue[pick] = node_queue[-1]
                node_queue.pop()
                continue

            ratio = np.random.triangular(0, self.p, 1)
            count = np.around(len(unvisited_neighbors) * ratio)
            remaining = target - len(self._sampled_nodes)
            count = int(min(count, remaining))

            # random.sample() rejects sets on Python 3.11+, so hand it a
            # sequence.
            burned_neighbors = random.sample(list(unvisited_neighbors), count)

            if self.connectivity:
                visited.extendleft(
                    unvisited_neighbors.difference(burned_neighbors)
                )
            node_queue.extend(burned_neighbors)

        # Every sample additionally contains all validation nodes, when the
        # data object carries a validation mask.
        val_mask = getattr(self.data, 'val_mask', None)
        if val_mask is not None:
            self._sampled_nodes.update(
                np.where(val_mask.numpy())[0].tolist()
            )

        sampled_graph = self.data_nx.subgraph(self._sampled_nodes)
        edges = list(sampled_graph.edges)
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        # view(2, -1) keeps the (2, 0) shape when no edge was induced.
        return edge_index.view(2, -1)

    @property
    def __filename__(self):
        """File name built from the class name, ``p`` and ``sample_coverage``."""
        return (f'{self.__class__.__name__.lower()}_{self.p}_'
                f'{self.sample_coverage}.pt')
def _to_networkx(
    data: 'torch_geometric.data.Data',
    to_undirected: Optional[Union[bool, str]] = False,
    remove_self_loops: bool = False,
) -> Any:
    """Build a NetworkX graph from the nodes and edges of ``data``.

    Args:
        data: Object exposing ``num_nodes`` and an ``edge_index`` whose
            transposed rows are ``(source, target)`` pairs.
        to_undirected: Falsy -> a ``nx.DiGraph`` is built from every edge.
            Truthy -> a ``nx.Graph`` is built; ``True`` or ``"upper"`` skips
            edges with ``source > target``, ``"lower"`` skips edges with
            ``source < target``.
        remove_self_loops: If ``True``, edges with ``source == target`` are
            skipped.

    Returns:
        The constructed ``nx.Graph`` or ``nx.DiGraph`` with nodes
        ``0 .. data.num_nodes - 1``.
    """
    graph = nx.Graph() if to_undirected else nx.DiGraph()
    graph.add_nodes_from(range(data.num_nodes))

    # A plain `True` is shorthand for the "upper" triangle.
    triangle = "upper" if to_undirected is True else to_undirected
    upper_only = triangle == "upper"
    lower_only = triangle == "lower"

    for src, dst in data.edge_index.t().tolist():
        dropped = (
            (upper_only and src > dst)
            or (lower_only and src < dst)
            or (remove_self_loops and src == dst)
        )
        if not dropped:
            graph.add_edge(src, dst)

    return graph
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment