import logging
import pickle as pkl
import random

import torch.utils.data as data

from fedml.core import partition_class_samples_with_dirichlet_distribution
import dgl
import networkx as nx
import json
from tqdm import tqdm
import os
import numpy as np
from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata


class WgetDataset(dgl.data.DGLDataset):
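    """Graph-level wget provenance dataset.

    Loads pre-extracted node-link JSON graphs from the hard-coded benign and
    attack directories and exposes them as (DGLGraph, label) pairs, where
    label 0 = benign and label 1 = attack.
    """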
    def process(self):
        pass

    def __init__(self, name):
        super(WgetDataset, self).__init__(name=name)
        if name == 'wget':
            path_attack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack'
            path_benign = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin'
            num_graphs_benign = 125
            num_graphs_attack = 25
            self.graphs = []
            self.labels = []
            print('Loading {} dataset...'.format(name))
            # Benign graphs (label 0)
            for i in tqdm(range(num_graphs_benign)):
                with open('{}/{}.json'.format(path_benign, i)) as f:
                    g = dgl.from_networkx(
                        nx.node_link_graph(json.load(f)),
                        node_attrs=['type'],
                        edge_attrs=['type']
                    )
                self.graphs.append(g)
                self.labels.append(0)

            # Attack graphs (label 1)
            for i in tqdm(range(num_graphs_attack)):
                with open('{}/{}.json'.format(path_attack, i)) as f:
                    g = dgl.from_networkx(
                        nx.node_link_graph(json.load(f)),
                        node_attrs=['type'],
                        edge_attrs=['type']
                    )
                self.graphs.append(g)
                self.labels.append(1)
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)

def darpa_split(name):
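    """Load the DARPA entity-level training graphs for dataset `name`.

    All training graphs are benign (label 0). This loader provides no
    graph-level validation/test split for the entity-level data, so the
    val/test slots of the returned tuple are empty lists.
    """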
    device = "cpu"
    metadata = load_metadata(name)
    n_train = metadata['n_train']
    train_dataset = []
    train_labels = []
    for i in range(n_train):
        g = load_entity_level_dataset(name, 'train', i).to(device)
        train_dataset.append(g)
        train_labels.append(0)
    
    return (
        train_dataset,
        train_labels,
        [],
        [],
        [],
        []
    )
       

def create_random_split(name):
    dataset = load_rawdata(name)
    # Random 80/10/10 train/val/test split
    train_range = (0, int(0.8 * len(dataset)))
    val_range = (
        int(0.8 * len(dataset)),
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
    )
    test_range = (
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
        len(dataset),
    )

    all_idxs = list(range(len(dataset)))
    random.shuffle(all_idxs)

    train_dataset = [dataset[all_idxs[i]] for i in range(train_range[0], train_range[1])]
    train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])]

    val_dataset = [dataset[all_idxs[i]] for i in range(val_range[0], val_range[1])]
    val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])]

    test_dataset = [dataset[all_idxs[i]] for i in range(test_range[0], test_range[1])]
    test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])]

    return (
        train_dataset,
        train_labels,
        val_dataset,
        val_labels,
        test_dataset,
        test_labels,
    )

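# NOTE: `create_non_uniform_split` is referenced below but is not defined or imported
# in this file. The sketch here is a guess at the missing helper, modelled on the
# pattern used in FedML's fedgraphnn examples: it assumes `args.partition_alpha`
# exists and that `partition_class_samples_with_dirichlet_distribution` (imported
# above) has the signature (N, alpha, client_num, idx_batch, idx_k) -> (idx_batch, min_size).
def create_non_uniform_split(args, idxs, client_number, is_train=True):
    N = len(idxs)
    alpha = args.partition_alpha
    logging.info("create_non_uniform_split: N = %d, client_number = %d" % (N, client_number))
    # Spread the shuffled indices across clients following a Dirichlet(alpha) distribution.
    idx_batch_per_client = [[] for _ in range(client_number)]
    idx_batch_per_client, _ = partition_class_samples_with_dirichlet_distribution(
        N, alpha, client_number, idx_batch_per_client, idxs
    )
    for client_id in range(client_number):
        logging.info(
            "client_id = %d, local sample number = %d"
            % (client_id, len(idx_batch_per_client[client_id]))
        )
    return idx_batch_per_client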


def partition_data_by_sample_size(
    args, client_number, name, uniform=True, compact=True
):
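    """Split the `name` dataset across `client_number` clients.

    With uniform=True the shuffled train/val/test indices are divided evenly
    with np.array_split; otherwise a non-uniform (Dirichlet-based) split is
    used. Returns (global_data_dict, partition_dicts), where each per-client
    dict holds its "train"/"val"/"test" sample lists.
    """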
    if name in ('wget', 'streamspot'):
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = create_random_split(name)
    else:
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = darpa_split(name)
        
    num_train_samples = len(train_dataset)
    num_val_samples = len(val_dataset)
    num_test_samples = len(test_dataset)

    train_idxs = list(range(num_train_samples))
    val_idxs = list(range(num_val_samples))
    test_idxs = list(range(num_test_samples))

    random.shuffle(train_idxs)
    random.shuffle(val_idxs)
    random.shuffle(test_idxs)

    partition_dicts = [None] * client_number

    if uniform:
        clients_idxs_train = np.array_split(train_idxs, client_number)
        clients_idxs_val = np.array_split(val_idxs, client_number)
        clients_idxs_test = np.array_split(test_idxs, client_number)
    else:
        clients_idxs_train = create_non_uniform_split(
            args, train_idxs, client_number, True
        )
        clients_idxs_val = create_non_uniform_split(
            args, val_idxs, client_number, False
        )
        clients_idxs_test = create_non_uniform_split(
            args, test_idxs, client_number, False
        )

    labels_of_all_clients = []
    for client in range(client_number):
        client_train_idxs = clients_idxs_train[client]
        client_val_idxs = clients_idxs_val[client]
        client_test_idxs = clients_idxs_test[client]

        train_dataset_client = [
            train_dataset[idx] for idx in client_train_idxs
        ]
        train_labels_client = [train_labels[idx] for idx in client_train_idxs]
        labels_of_all_clients.append(train_labels_client)

        val_dataset_client = [val_dataset[idx] for idx in client_val_idxs]
        val_labels_client = [val_labels[idx] for idx in client_val_idxs]

        test_dataset_client = [test_dataset[idx] for idx in client_test_idxs]
        test_labels_client = [test_labels[idx] for idx in client_test_idxs]


        partition_dict = {
            "train": train_dataset_client,
            "val": val_dataset_client,
            "test": test_dataset_client,
        }

        partition_dicts[client] = partition_dict

    global_data_dict = {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset,
    }

    return global_data_dict, partition_dicts

def load_partition_data(
    args,
    client_number,
    name,
    uniform=True,
    global_test=True,
    compact=True,
    normalize_features=False,
    normalize_adj=False,
):
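    """Build the global and per-client data structures used by the FedML trainer.

    Returns a 10-tuple: the global train/val/test sample counts, the global
    train/val/test sample lists, the per-client training sample counts, and
    the per-client train/val/test dictionaries keyed by client index.
    """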
    global_data_dict, partition_dicts = partition_data_by_sample_size(
        args, client_number, name, uniform, compact=compact
    )

    data_local_num_dict = dict()
    train_data_local_dict = dict()
    val_data_local_dict = dict()
    test_data_local_dict = dict()

    # IMPORTANT: keep the batch size at 1 downstream; each "batch" is a single
    # provenance graph, not a mini-batch of samples.
    train_data_global =  global_data_dict["train"]
    val_data_global = global_data_dict["val"]
    test_data_global = global_data_dict["test"]
    train_data_num = len(global_data_dict["train"])
    val_data_num = len(global_data_dict["val"])
    test_data_num = len(global_data_dict["test"])

    for client in range(client_number):
        train_dataset_client = partition_dicts[client]["train"]
        val_dataset_client = partition_dicts[client]["val"]
        test_dataset_client = partition_dicts[client]["test"]

        data_local_num_dict[client] = len(train_dataset_client)
        train_data_local_dict[client] = train_dataset_client
        val_data_local_dict[client] = val_dataset_client
        test_data_local_dict[client] = (
            test_data_global if global_test else test_dataset_client
        )

        logging.info(
            "Client idx = {}, local sample number = {}".format(
                client, len(train_dataset_client)
            )
        )

    return (
        train_data_num,
        val_data_num,
        test_data_num,
        train_data_global,
        val_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        val_data_local_dict,
        test_data_local_dict,
    )



def load_batch_level_dataset_main(name):
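    """Build the batch-level dataset dict (graphs plus node/edge 'type' vocabulary sizes).

    Note: `get_data` is not defined or imported in this file; it is assumed to
    return an indexable collection of (DGLGraph, label) pairs, like WgetDataset.
    """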
    dataset = get_data(name)
    graph, _ = dataset[0]
    node_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
    edge_feature_dim = 0
    for g, _ in dataset:
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_index = list(range(len(dataset)))
    train_index = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))

    return {'dataset': dataset,
            'train_index': train_index,
            'full_index': full_index,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}
            
            
class GraphDataset(dgl.data.DGLDataset):
    """Thin DGLDataset wrapper around an in-memory list of (DGLGraph, label) pairs."""

    def process(self):
        # Nothing to build: the graphs are passed in already constructed.
        pass

    def __init__(self, graph_label_list):
        super(GraphDataset, self).__init__(name="wget")
        self.graph_label_list = graph_label_list

    def __len__(self):
        return len(self.graph_label_list)

    def __getitem__(self, idx):
        # Each entry is already a (DGLGraph, label) pair, so it can be returned as-is.
        graph, label = self.graph_label_list[idx]
        return graph, label

            
def transform_data(data):
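    """Wrap data[0] (a list of (DGLGraph, label) pairs) in a GraphDataset and
    compute the node/edge 'type' vocabulary sizes, mirroring the dict layout
    returned by load_batch_level_dataset_main."""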
    dataset = GraphDataset(data[0])
    graph, _ = dataset[0]
    node_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
    edge_feature_dim = 0
    for g, _ in dataset:
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_index = list(range(len(dataset)))
    train_index = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))

    return {'dataset': dataset,
            'train_index': train_index,
            'full_index': full_index,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}
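

# Hypothetical usage sketch (not part of the original file) showing how these
# pieces are typically combined for the batch-level 'wget' data. The `args`
# object and the (graphs, labels) pairing below are illustrative assumptions.
#
#   raw = WgetDataset('wget')
#   pairs = list(zip(raw.graphs, raw.labels))
#   batch_level = transform_data((pairs,))
#   splits = load_partition_data(args, client_number=4, name='wget')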