# data_loader.py
import json
import logging
import os
import pickle as pkl
import random

import dgl
import networkx as nx
import numpy as np
import torch.utils.data as data
from fedml.core import partition_class_samples_with_dirichlet_distribution
from tqdm import tqdm

from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata

class WgetDataset(dgl.data.DGLDataset):
    """Batch-level wget dataset: each sample is one provenance graph,
    labelled 0 (benign) or 1 (attack)."""

    def process(self):
        # Graphs are loaded eagerly in __init__, so nothing to do here.
        pass

    def __init__(self, name):
        super(WgetDataset, self).__init__(name=name)
        if name == 'wget':
            pathattack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack'
            pathbenin = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin'
            num_graphs_benin = 125
            num_graphs_attack = 25
            self.graphs = []
            self.labels = []
            print('Loading {} dataset...'.format(name))
            # Each JSON file holds one node-link graph; convert it to a
            # DGLGraph, keeping the categorical 'type' node/edge attributes.
            for idx in tqdm(range(num_graphs_benin)):
                with open('{}/{}.json'.format(pathbenin, idx)) as f:
                    g = dgl.from_networkx(
                        nx.node_link_graph(json.load(f)),
                        node_attrs=['type'],
                        edge_attrs=['type']
                    )
                self.graphs.append(g)
                self.labels.append(0)
            for idx in tqdm(range(num_graphs_attack)):
                with open('{}/{}.json'.format(pathattack, idx)) as f:
                    g = dgl.from_networkx(
                        nx.node_link_graph(json.load(f)),
                        node_attrs=['type'],
                        edge_attrs=['type']
                    )
                self.graphs.append(g)
                self.labels.append(1)
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)
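
# Example usage (a minimal sketch): the dataset yields (DGLGraph, label)
# pairs, with label 0 for the 125 benign graphs and 1 for the 25 attack
# graphs loaded above.
#
#     dataset = WgetDataset('wget')
#     g, label = dataset[0]
#     print(g.num_nodes(), g.ndata['type'].shape, label)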

def darpa_split(name):
    """Entity-level DARPA datasets provide only benign training graphs here;
    the val/test slots are returned empty to mirror create_random_split."""
    device = "cpu"
    metadata = load_metadata(name)
    n_train = metadata['n_train']
    train_dataset = []
    train_labels = []
    for i in range(n_train):
        g = load_entity_level_dataset(name, 'train', i).to(device)
        train_dataset.append(g)
        train_labels.append(0)
    return (
        train_dataset,
        train_labels,
        [],
        [],
        [],
        []
    )

def create_random_split(name):
    dataset = load_rawdata(name)
    # Random 80/10/10 split as suggested
    train_range = (0, int(0.8 * len(dataset)))
    val_range = (
        int(0.8 * len(dataset)),
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
    )
    test_range = (
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
        len(dataset),
    )
    all_idxs = list(range(len(dataset)))
    random.shuffle(all_idxs)
    train_dataset = [dataset[all_idxs[i]] for i in range(train_range[0], train_range[1])]
    train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])]
    val_dataset = [dataset[all_idxs[i]] for i in range(val_range[0], val_range[1])]
    val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])]
    test_dataset = [dataset[all_idxs[i]] for i in range(test_range[0], test_range[1])]
    test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])]
    return (
        train_dataset,
        train_labels,
        val_dataset,
        val_labels,
        test_dataset,
        test_labels,
    )
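
# NOTE: create_non_uniform_split() is called below but was missing from this
# file. The sketch here is a reconstruction based on FedML's fedgraphnn
# examples, assuming `args.partition_alpha` holds the Dirichlet concentration
# parameter; it is not the author's verified implementation.
def create_non_uniform_split(args, idxs, client_number, is_train=True):
    N = len(idxs)
    alpha = args.partition_alpha  # assumed field on args
    logging.info("sample number = %d, client_number = %d" % (N, client_number))
    idx_batch_per_client = [[] for _ in range(client_number)]
    # Draw per-client proportions from Dirichlet(alpha) and assign indices.
    idx_batch_per_client, min_size = partition_class_samples_with_dirichlet_distribution(
        N, alpha, client_number, idx_batch_per_client, idxs
    )
    logging.info("per-client sample counts = %s"
                 % str([len(batch) for batch in idx_batch_per_client]))
    return idx_batch_per_client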

def partition_data_by_sample_size(
    args, client_number, name, uniform=True, compact=True
):
    if name == 'wget' or name == 'streamspot':
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = create_random_split(name)
    else:
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = darpa_split(name)
    num_train_samples = len(train_dataset)
    num_val_samples = len(val_dataset)
    num_test_samples = len(test_dataset)
    train_idxs = list(range(num_train_samples))
    val_idxs = list(range(num_val_samples))
    test_idxs = list(range(num_test_samples))
    random.shuffle(train_idxs)
    random.shuffle(val_idxs)
    random.shuffle(test_idxs)
    partition_dicts = [None] * client_number
    if uniform:
        # Uniform: split indices into (almost) equally sized chunks.
        clients_idxs_train = np.array_split(train_idxs, client_number)
        clients_idxs_val = np.array_split(val_idxs, client_number)
        clients_idxs_test = np.array_split(test_idxs, client_number)
    else:
        # Non-uniform: Dirichlet-based split (see create_non_uniform_split).
        clients_idxs_train = create_non_uniform_split(
            args, train_idxs, client_number, True
        )
        clients_idxs_val = create_non_uniform_split(
            args, val_idxs, client_number, False
        )
        clients_idxs_test = create_non_uniform_split(
            args, test_idxs, client_number, False
        )
    labels_of_all_clients = []
    for client in range(client_number):
        client_train_idxs = clients_idxs_train[client]
        client_val_idxs = clients_idxs_val[client]
        client_test_idxs = clients_idxs_test[client]
        train_dataset_client = [train_dataset[idx] for idx in client_train_idxs]
        train_labels_client = [train_labels[idx] for idx in client_train_idxs]
        labels_of_all_clients.append(train_labels_client)
        val_dataset_client = [val_dataset[idx] for idx in client_val_idxs]
        val_labels_client = [val_labels[idx] for idx in client_val_idxs]
        test_dataset_client = [test_dataset[idx] for idx in client_test_idxs]
        test_labels_client = [test_labels[idx] for idx in client_test_idxs]
        partition_dict = {
            "train": train_dataset_client,
            "val": val_dataset_client,
            "test": test_dataset_client,
        }
        partition_dicts[client] = partition_dict
    global_data_dict = {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset,
    }
    return global_data_dict, partition_dicts

def load_partition_data(
    args,
    client_number,
    name,
    uniform=True,
    global_test=True,
    compact=True,
    normalize_features=False,
    normalize_adj=False,
):
    global_data_dict, partition_dicts = partition_data_by_sample_size(
        args, client_number, name, uniform, compact=compact
    )
    data_local_num_dict = dict()
    train_data_local_dict = dict()
    val_data_local_dict = dict()
    test_data_local_dict = dict()
    # IMPORTANT: the batch size must be 1 -- each batch is one entire graph.
    train_data_global = global_data_dict["train"]
    val_data_global = global_data_dict["val"]
    test_data_global = global_data_dict["test"]
    train_data_num = len(global_data_dict["train"])
    val_data_num = len(global_data_dict["val"])
    test_data_num = len(global_data_dict["test"])
    for client in range(client_number):
        train_dataset_client = partition_dicts[client]["train"]
        val_dataset_client = partition_dicts[client]["val"]
        test_dataset_client = partition_dicts[client]["test"]
        data_local_num_dict[client] = len(train_dataset_client)
        # Note: no trailing comma here -- it would wrap the list in a tuple.
        train_data_local_dict[client] = train_dataset_client
        val_data_local_dict[client] = val_dataset_client
        test_data_local_dict[client] = (
            test_data_global if global_test else test_dataset_client
        )
        logging.info(
            "Client idx = {}, local sample number = {}".format(
                client, len(train_dataset_client)
            )
        )
    return (
        train_data_num,
        val_data_num,
        test_data_num,
        train_data_global,
        val_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        val_data_local_dict,
        test_data_local_dict,
    )
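
# Example call (a sketch; `args` only needs the fields used above, e.g.
# args.partition_alpha when uniform=False):
#
#     (train_num, val_num, test_num,
#      train_global, val_global, test_global,
#      local_num_dict, train_local_dict,
#      val_local_dict, test_local_dict) = load_partition_data(
#          args, client_number=4, name='wget', uniform=True)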

def load_batch_level_dataset_main(name):
    # `get_data` is expected to be provided elsewhere and to return a dataset
    # of (DGLGraph, label) pairs; it is not defined in this module.
    dataset = get_data(name)
    # The 'type' attributes are categorical indices, so the feature dimension
    # is the largest index plus one.
    node_feature_dim = 0
    edge_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_dataset = [i for i in range(len(dataset))]
    # Train only on benign graphs (label 0).
    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(
        len(dataset), node_feature_dim, edge_feature_dim))
    return {'dataset': dataset,
            'train_index': train_dataset,
            'full_index': full_dataset,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}

class GraphDataset(dgl.data.DGLDataset):
    """Thin DGLDataset wrapper around an in-memory list of (graph, label) pairs."""

    def __init__(self, graph_label_list):
        # Set the data before super().__init__, which triggers the DGLDataset
        # loading pipeline (has_cache/process/...).
        self.graph_label_list = graph_label_list
        super(GraphDataset, self).__init__(name="wget")

    def process(self):
        # Required override: the graphs are already built, nothing to process.
        pass

    def __len__(self):
        return len(self.graph_label_list)

    def __getitem__(self, idx):
        graph, label = self.graph_label_list[idx]
        return graph, label

def transform_data(data):
    # Mirrors load_batch_level_dataset_main, but wraps an already loaded list
    # of (graph, label) pairs (data[0]) instead of reading from disk.
    dataset = GraphDataset(data[0])
    node_feature_dim = 0
    edge_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_dataset = [i for i in range(len(dataset))]
    # Train only on benign graphs (label 0).
    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(
        len(dataset), node_feature_dim, edge_feature_dim))
    return {'dataset': dataset,
            'train_index': train_dataset,
            'full_index': full_dataset,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}
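
# Example (a sketch): transform_data expects a container whose first element
# is a list of (DGLGraph, label) pairs, e.g. a client's train split produced
# by partition_data_by_sample_size:
#
#     global_data_dict, partition_dicts = partition_data_by_sample_size(
#         args, client_number=4, name='wget')
#     meta = transform_data((partition_dicts[0]['train'],))
#     print(meta['n_feat'], meta['e_feat'])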