Skip to content
Snippets Groups Projects
Commit b4e80745 authored by Abderaouf Gacem's avatar Abderaouf Gacem
Browse files

Upload New File

parent 0768b6f7
No related branches found
No related tags found
No related merge requests found
models.py 0 → 100644
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.nn import GraphConv, SAGEConv, GATConv, GCNConv
class GraphSaintPyGNet(torch.nn.Module):
    """Three-layer GraphConv network with jumping-knowledge-style skip
    concatenation, intended for GraphSAINT subgraph training.

    The pre-activation output of every layer is concatenated and fed to a
    final linear classifier, so ``self.lin`` consumes ``3 * hidden_channels``
    features.
    """

    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        # Every conv emits hidden_channels so the concatenation of the three
        # layer outputs matches the 3 * hidden_channels input of self.lin.
        self.convs = ModuleList([
            GraphConv(in_channels, hidden_channels),
            GraphConv(hidden_channels, hidden_channels),
            GraphConv(hidden_channels, hidden_channels),
        ])
        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        """Set the neighborhood aggregation scheme of every conv layer
        (GraphSAINT switches between normalized and unnormalized aggregation
        for training vs. full-graph evaluation)."""
        for conv in self.convs:
            conv.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        """Return per-node class log-probabilities.

        The *pre-activation* output of each conv layer is accumulated into
        ``x_all``; ReLU + dropout are applied only to the features passed on
        to the next layer.
        """
        x = x0
        x_all = torch.Tensor().to(x0.device)
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_weight)
            x_all = torch.cat([x_all, x], dim=-1)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin(x_all)
        return x.log_softmax(dim=-1)

    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise full-neighborhood inference.

        Computes node representations one layer at a time, using *all*
        available edges — faster than pushing every batch through the whole
        network at once.

        Bug fix: ``forward`` concatenates *pre-activation* layer outputs, but
        the original inference code concatenated the *post-activation* batch
        outputs, so the inputs to ``self.lin`` differed between training and
        evaluation.  Pre-activation outputs are now collected separately for
        the skip concatenation.
        """
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')
        x_concat_layers = torch.Tensor().to(device)
        for i, conv in enumerate(self.convs):
            xs = []       # post-activation outputs: input to the next layer
            xs_pre = []   # pre-activation outputs: skip concatenation
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                # Target nodes are the first size[1] entries of n_id.
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                xs_pre.append(x.cpu())
                if i != len(self.convs) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(xs, dim=0)
            x_concat_layers = torch.cat(
                [x_concat_layers, torch.cat(xs_pre, dim=0).to(device)], dim=-1)
        x = self.lin(x_concat_layers)
        pbar.close()
        return F.log_softmax(x, dim=-1)
class ClusterGCNPyGNet(torch.nn.Module):
    """Two-layer GraphSAGE network for Cluster-GCN-style mini-batch
    training, with a layer-wise full-graph inference path."""

    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        self.convs = ModuleList([
            SAGEConv(in_channels, hidden_channels),
            SAGEConv(hidden_channels, out_channels),
        ])

    def forward(self, x, edge_index):
        """Full-batch forward pass; returns per-node log-probabilities."""
        last = len(self.convs) - 1
        for layer, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # ReLU + dropout on every layer except the output layer.
            if layer < last:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)

    def inference(self, x_all, subgraph_loader, device):
        """Evaluate layer by layer using *all* available edges.

        Computing representations one layer at a time is faster than
        immediately computing the final representation of each batch.
        """
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')
        last = len(self.convs) - 1
        for layer, conv in enumerate(self.convs):
            outputs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                h = x_all[n_id].to(device)
                # Target nodes are the first size[1] entries of n_id.
                h_target = h[:size[1]]
                h = conv((h, h_target), edge_index)
                if layer < last:
                    h = F.relu(h)
                outputs.append(h.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(outputs, dim=0)
        pbar.close()
        return F.log_softmax(x_all, dim=-1)
class GAT(torch.nn.Module):
    """Two-layer graph attention network with heavy (p=0.6) dropout on both
    the inputs and the hidden representation."""

    def __init__(self, in_channels, out_channels, hidden_channels, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
        # On the Pubmed dataset, use `heads` output heads in `conv2`.
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1,
                             concat=False, dropout=0.6)

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node class log-probabilities."""
        h = F.dropout(x, p=0.6, training=self.training)
        h = self.conv1(h, edge_index, edge_weight)
        h = F.elu(h)
        h = F.dropout(h, p=0.6, training=self.training)
        h = self.conv2(h, edge_index, edge_weight)
        return F.log_softmax(h, dim=-1)
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network (Kipf & Welling) with a cached,
    normalized adjacency for full-batch transductive training."""

    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True,
                             normalize=True)
        self.conv2 = GCNConv(hidden_channels, out_channels, cached=True,
                             normalize=True)

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node class log-probabilities."""
        h = F.dropout(x, p=0.5, training=self.training)
        h = F.relu(self.conv1(h, edge_index, edge_weight))
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.conv2(h, edge_index, edge_weight)
        return F.log_softmax(h, dim=-1)
class GraphSaintGCN(torch.nn.Module):
    """Three-layer GCN with jumping-knowledge-style skip concatenation, for
    GraphSAINT training.

    Bug fix: the last convolution originally mapped ``hidden_channels ->
    out_channels`` while ``self.lin`` expects a ``3 * hidden_channels``
    input, so the concatenation in ``forward`` was ``hidden + hidden +
    out_channels`` wide and crashed whenever ``out_channels !=
    hidden_channels``.  The last convolution now emits ``hidden_channels``,
    mirroring ``GraphSaintPyGNet``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        # NOTE(review): cached=True memoizes the normalized adjacency of the
        # first graph seen; with GraphSAINT's varying sampled subgraphs this
        # may be stale — confirm intended usage.
        self.convs = ModuleList([
            GCNConv(in_channels, hidden_channels, cached=True,
                    normalize=True),
            GCNConv(hidden_channels, hidden_channels, cached=True,
                    normalize=True),
            GCNConv(hidden_channels, hidden_channels, cached=True,
                    normalize=True),
        ])
        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        """Set the neighborhood aggregation scheme of every conv layer."""
        for conv in self.convs:
            conv.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        """Return per-node class log-probabilities.

        The *pre-activation* output of each conv layer is accumulated into
        ``x_all``; ReLU + dropout are applied only to the features passed on
        to the next layer.
        """
        x = x0
        x_all = torch.Tensor().to(x0.device)
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_weight)
            x_all = torch.cat([x_all, x], dim=-1)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin(x_all)
        return x.log_softmax(dim=-1)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment