Hamida Seba / SFFS-GCN · Commits · 0768b6f7

Commit 0768b6f7, authored 2 years ago by Abderaouf Gacem
Upload New File
parent a6f99a87

Showing 1 changed file with 158 additions and 0 deletions

ForestFireSampler.py (new file, mode 100644) · +158 −0
import torch
import torch_geometric
from torch_geometric.loader import GraphSAINTSampler
import random
import numpy as np
import networkx as nx
from collections import deque
from typing import Any, Optional, Union
from tqdm import tqdm


class ForestFireSampler(GraphSAINTSampler):
    """GraphSAINT-style sampler that grows each mini-batch subgraph with a
    forest-fire burning process over a NetworkX view of the input graph."""

    def __init__(self, data, batch_size: int = 100, p: float = 0.5,
                 seed: int = 100, restart_hop_size: int = 10,
                 connectivity: bool = False, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):
        self.p = p                        # mode of the burning-ratio distribution
        self.connectivity = connectivity  # restart the fire from visited-but-unburned nodes
        self.seed = seed
        self._set_seed()
        self.restart_hop_size = restart_hop_size  # stored; not used by this version
        self.data_nx = _to_networkx(data).to_undirected()
        super().__init__(data, batch_size, num_steps, sample_coverage,
                         save_dir, log, **kwargs)

    def __compute_norm__(self):
        # Estimate how often each node and edge is sampled, then derive the
        # GraphSAINT normalization coefficients from those counts.
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def _set_seed(self):
        random.seed(self.seed)
        np.random.seed(self.seed)

    def __sample_nodes__(self, batch_size):
        # Burn nodes with a forest-fire process until the batch budget is reached.
        self._sampled_nodes = set()
        if self.connectivity:
            visited = deque()
        node_queue = []
        while len(self._sampled_nodes) < self.__batch_size__:
            if len(node_queue) == 0:
                # The fire died out: restart from a visited-but-unburned node
                # (if connectivity is requested) or from a fresh random seed.
                if self.connectivity and len(visited):
                    seed_node = visited.popleft()
                else:
                    while True:
                        seed_node = np.random.randint(0, self.data_nx.number_of_nodes())
                        if seed_node not in self._sampled_nodes:
                            break
                node_queue.append(seed_node)
            top_node = random.sample(node_queue, 1)[0]
            self._sampled_nodes.add(top_node)
            neighbors = set(self.data_nx.neighbors(top_node))
            unvisited_neighbors = neighbors.difference(self._sampled_nodes)
            # Burn a random fraction of the unvisited neighbors, drawn from a
            # triangular distribution whose mode is p, capped by the remaining budget.
            ratio = np.random.triangular(0, self.p, 1)
            count = np.around(len(unvisited_neighbors) * ratio)
            if (self.__batch_size__ - len(self._sampled_nodes)) < count:
                count = self.__batch_size__ - len(self._sampled_nodes)
            # random.sample() requires a sequence (not a set) on Python >= 3.11.
            burned_neighbors = random.sample(list(unvisited_neighbors), int(count))
            if self.connectivity:
                visited.extendleft(unvisited_neighbors.difference(set(burned_neighbors)))
            node_queue.extend(burned_neighbors)
        # Always include the validation nodes in the sampled subgraph.
        self._sampled_nodes.update(np.where(self.data.val_mask.numpy())[0].tolist())
        sampled_graph = self.data_nx.subgraph(self._sampled_nodes)
        edges = list(sampled_graph.edges)
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        return edge_index.view(2, -1)

    @property
    def __filename__(self):
        return (f'{self.__class__.__name__.lower()}_{self.p}_'
                f'{self.sample_coverage}.pt')


def _to_networkx(
    data: 'torch_geometric.data.Data',
    to_undirected: Optional[Union[bool, str]] = False,
    remove_self_loops: bool = False,
) -> Any:
    # Convert a torch_geometric Data object to a NetworkX graph, keeping only
    # the topology (no node or edge attributes).
    G = nx.Graph() if to_undirected else nx.DiGraph()

    G.add_nodes_from(range(data.num_nodes))

    to_undirected = "upper" if to_undirected is True else to_undirected
    to_undirected_upper = to_undirected == "upper"
    to_undirected_lower = to_undirected == "lower"

    for (u, v) in data.edge_index.t().tolist():
        if to_undirected_upper and u > v:
            continue
        elif to_undirected_lower and u < v:
            continue
        if remove_self_loops and u == v:
            continue

        G.add_edge(u, v)

    return G
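
For context, a minimal usage sketch of the new sampler, assuming it is dropped into the standard GraphSAINT loader loop of PyTorch Geometric. The Planetoid/Cora dataset, the batch_size of 512, the num_steps value, and the import path `ForestFireSampler` are illustrative assumptions, not part of this commit.

# Usage sketch (assumptions: Planetoid/Cora dataset, ForestFireSampler.py is on
# the Python path, and the installed PyG version exposes the GraphSAINTSampler
# hooks this class overrides).
from torch_geometric.datasets import Planetoid

from ForestFireSampler import ForestFireSampler

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]  # Cora provides the val_mask that __sample_nodes__ relies on

loader = ForestFireSampler(
    data,
    batch_size=512,      # assumed target number of burned nodes per subgraph
    p=0.5,               # mode of the triangular burning-ratio distribution
    num_steps=5,         # assumed number of subgraphs yielded per epoch
    sample_coverage=0,   # skip normalization-coefficient estimation
)

for batch in loader:
    # Each batch is a sampled subgraph (a torch_geometric Data object) that can
    # be fed to a GCN as in standard GraphSAINT training.
    print(batch.num_nodes, batch.num_edges)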