Commit 0f70617a authored by Kamel Souaid Ferrahi

first commit

parent 57b22966

Showing 758 additions and 70 deletions
# FEDHE-Graph
Welcome to the official repository housing the FEDHE-Graph implementation for MAGIC! This repository provides the tools and resources needed to apply federated learning techniques to MAGIC, a masked graph representation learning approach for APT detection.
![architecture](https://gitlab.liris.cnrs.fr/gladis/graphfl/-/raw/main/assets/archiFedHe.png?ref_type=heads)
Original project: https://github.com/FDUDSDE/MAGIC
## Environment Setup
The commands below assume an environment running Ubuntu 22.04 with Miniconda installed.
First, create the conda environment for FedML with MPI support:
```
conda create --name fedml-pip python=3.8
conda activate fedml-pip
conda install --name fedml-pip pip
conda install -c conda-forge mpi4py openmpi
pip install "fedml[MPI]"
```
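Optionally, you can sanity-check that the MPI runtime and the `fedml` package are usable before continuing (a quick verification sketch, not required):
```
mpirun --version
python -c "import fedml; from mpi4py import MPI; print('fedml + MPI OK')"
```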
Clone the MAGIC FedML project into your current folder:
```
git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML
```
Install the remaining packages required for MAGIC to run (pinned to the versions the project was tested against):
```
pip install aiohttp==3.9.1 aiosignal==1.3.1 anyio==4.2.0 attrdict==2.0.1 attrs==23.2.0 blis==0.7.11 boto3==1.34.12 botocore==1.34.12 brotli==1.1.0 catalogue==2.0.10 certifi==2023.11.17 chardet==5.2.0 charset-normalizer==3.3.2 click==8.1.7 cloudpathlib==0.16.0 confection==0.1.4 contourpy==1.2.0 cycler==0.12.1 cymem==2.0.8 dgl==1.1.3 dill==0.3.7 fastapi==0.92.0 fedml==0.8.13.post2 filelock==3.13.1 fonttools==4.47.0 frozenlist==1.4.1 fsspec==2023.12.2 gensim==4.3.2 gevent==23.9.1 geventhttpclient==2.0.9 gitdb==4.0.11 GitPython==3.1.40 GPUtil==1.4.0 graphviz==0.8.4 greenlet==3.0.3 h11==0.14.0 h5py==3.10.0 httpcore==1.0.2 httpx==0.26.0 idna==3.6 Jinja2==3.1.2 jmespath==1.0.1 joblib==1.3.2 kiwisolver==1.4.5 langcodes==3.3.0 MarkupSafe==2.1.3 matplotlib==3.8.2 mpi4py==3.1.3 mpmath==1.3.0 multidict==6.0.4 multiprocess==0.70.15 murmurhash==1.0.10 networkx==2.8.8 ntplib==0.4.0 numpy==1.26.3 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 onnx==1.15.0 packaging==23.2 paho-mqtt==1.6.1 pandas==2.1.4 pathtools==0.1.2 pillow==10.2.0 preshed==3.0.9 prettytable==3.9.0 promise==2.3 protobuf==3.20.3 psutil==5.9.7 py-machineid==0.4.6 pydantic==1.10.13 pyparsing==3.1.1 python-dateutil==2.8.2 python-rapidjson==1.14 pytz==2023.3.post1 PyYAML==6.0.1 redis==5.0.1 requests==2.31.0 s3transfer==0.10.0 scikit-learn==1.3.2 scipy==1.11.4 sentry-sdk==1.39.1 setproctitle==1.3.3 shortuuid==1.0.11 six==1.16.0 smart-open==6.3.0 smmap==5.0.1 sniffio==1.3.0 spacy==3.7.2 spacy-legacy==3.0.12 spacy-loggers==1.0.5 SQLAlchemy==2.0.25 srsly==2.4.8 starlette==0.25.0 sympy==1.12 thinc==8.2.2 threadpoolctl==3.2.0 torch==2.1.2 torch-cluster==1.6.3 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-spline-conv==1.2.2 torch_geometric==2.4.0 torchvision==0.16.2 tqdm==4.66.1 triton==2.1.0 tritonclient==2.41.0 typer==0.9.0 typing_extensions==4.9.0 tzdata==2023.4 tzlocal==5.2 urllib3==2.0.7 uvicorn==0.25.0 wandb==0.13.2 wasabi==1.1.2 wcwidth==0.2.12 weasel==0.3.4 websocket-client==1.7.0 wget==3.2 yarl==1.9.4 zope.event==5.0 zope.interface==6.1
```
Finally, run the federated algorithm using `mpirun`:
```
hostname > mpi_host_file
mpirun -np 4 -hostfile mpi_host_file --oversubscribe python main.py --cf fedml_config.yaml
```
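Here `hostname > mpi_host_file` writes the local host name into a host file for MPI, `-np 4` launches four MPI processes, and `--oversubscribe` lets MPI start more processes than available cores; with FedML's MPI backend, one rank typically acts as the aggregation server while the remaining ranks run clients.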
## Federated learning parameters
You can adjust federated learning parameters in the `fedml_config.yaml` file.
Parameters such as the aggregation algorithm, the total number of clients, and the number of clients sampled per round can be modified:
```
train_args:
  federated_optimizer: "FedAvg"
  client_id_list:
  client_num_in_total: 4
  client_num_per_round: 4
```
The aggregation algorithms tested are `FedAvg`, `FedProx`, and `FedOpt`.
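Switching between them only requires changing the optimizer line in `fedml_config.yaml`; for example, a minimal sketch for `FedProx` (all other `train_args` unchanged):
```
train_args:
  federated_optimizer: "FedProx"
```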
## Datasets
The experiments utilize datasets similar to those in the original MAGIC project. To change datasets, edit the `fedml_config.yaml` file:
```
data_args:
  dataset: "wget"
  data_cache_dir: ~/fedgraphnn_data/
  part_file: ~/fedgraphnn_data/partition
```
Feel free to explore and modify these settings according to your specific requirements!
a.yaml 0 → 100644
common_args:
  training_type: "cross_silo"
  scenario: "horizontal"
  using_mlops: false
  config_version: release
  name: "exp"
  project: "runs/train"
  exist_ok: false
  random_seed: 0

common_args:
  training_type: "simulation"
  random_seed: 0

comm_args:
  backend: "MPI"
  is_mobile: 0

data_args:
  dataset: "clintox"
  data_cache_dir: ~/fedgraphnn_data/
  part_file: ~/fedgraphnn_data/partition
  partition_method: "hetero"
  partition_alpha: 0.5

model_args:
  model: "graphsage"
  hidden_size: 32
  node_embedding_dim: 32
  graph_embedding_dim: 64
  readout_hidden_dim: 64
  alpha: 0.2
  num_heads: 2
  dropout: 0.3
  normalize_features: False
  normalize_adjacency: False
  sparse_adjacency: False
  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
  global_model_file_path: "./model_file_cache/global_model.pt"

environment_args:
  bootstrap: config/bootstrap.sh

train_args:
  federated_optimizer: "FedAvg"
  client_id_list:
  client_num_in_total: 3
  client_num_per_round: 3
  comm_round: 100
  epochs: 5
  batch_size: 64
  client_optimizer: sgd
  learning_rate: 0.03
  weight_decay: 0.001
  metric: "prc-auc"
  server_optimizer: sgd
  lr: 0.001
  server_lr: 0.001
  wd: 0.001
  ci: 0
  server_momentum: 0.9

validation_args:
  frequency_of_the_test: 1

device_args:
  worker_num: 3
  using_gpu: false
  gpu_mapping_file: config/gpu_mapping.yaml
  gpu_mapping_key: mapping_fedgraphnn_sp

comm_args:
  backend: "MQTT_S3"
  mqtt_config_path: config/mqtt_config.yaml
  s3_config_path: config/s3_config.yaml

tracking_args:
  # When running on the MLOps platform (open.fedml.ai), the default log path is ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
  enable_wandb: false
  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
  wandb_project: fedml
  wandb_name: fedml_torch_moleculenet
a.yml 0 → 100644
common_args:
  training_type: "simulation"
  random_seed: 0

data_args:
  dataset: "clintox"
  data_cache_dir: ~/fedgraphnn_data/
  part_file: ~/fedgraphnn_data/partition
  partition_method: "hetero"
  partition_alpha: 0.5

model_args:
  model: "graphsage"
  hidden_size: 32
  node_embedding_dim: 32
  graph_embedding_dim: 64
  readout_hidden_dim: 64
  alpha: 0.2
  num_heads: 2
  dropout: 0.3
  normalize_features: False
  normalize_adjacency: False
  sparse_adjacency: False
  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
  global_model_file_path: "./model_file_cache/global_model.pt"

environment_args:
  bootstrap: config/bootstrap.sh

train_args:
  federated_optimizer: "FedAvg"
  client_id_list:
  client_num_in_total: 3
  client_num_per_round: 3
  comm_round: 100
  epochs: 5
  batch_size: 64
  client_optimizer: sgd
  learning_rate: 0.03
  weight_decay: 0.001
  metric: "prc-auc"
  server_optimizer: sgd
  lr: 0.001
  server_lr: 0.001
  wd: 0.001
  ci: 0
  server_momentum: 0.9

validation_args:
  frequency_of_the_test: 1

device_args:
  worker_num: 2
  using_gpu: false
  gpu_mapping_file: config/gpu_mapping.yaml
  gpu_mapping_key: mapping_fedgraphnn_sp

comm_args:
  backend: "MPI"
  is_mobile: 0

tracking_args:
  # When running on the MLOps platform (open.fedml.ai), the default log path is ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
  enable_wandb: false
  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
  wandb_project: fedml
  wandb_name: fedml_torch_moleculenet
assets/archiFedHe.png
980 KiB
File added

data/data_loader.py 0 → 100644
import logging
import pickle as pkl
import random
import torch.utils.data as data
from fedml.core import partition_class_samples_with_dirichlet_distribution
import dgl
import networkx as nx
import json
from tqdm import tqdm
import os
import numpy as np
from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata
class WgetDataset(dgl.data.DGLDataset):
    def process(self):
        pass

    def __init__(self, name):
        super(WgetDataset, self).__init__(name=name)
        if name == 'wget':
            pathattack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack'
            pathbenin = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin'
            num_graphs_benin = 125
            num_graphs_attack = 25
            self.graphs = []
            self.labels = []
            print('Loading {} dataset...'.format(name))
            # Benign provenance graphs (label 0)
            for i in tqdm(range(num_graphs_benin)):
                idx = i
                g = dgl.from_networkx(
                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))),
                    node_attrs=['type'],
                    edge_attrs=['type']
                )
                self.graphs.append(g)
                self.labels.append(0)
            # Attack graphs (label 1)
            for i in tqdm(range(num_graphs_attack)):
                idx = i
                g = dgl.from_networkx(
                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))),
                    node_attrs=['type'],
                    edge_attrs=['type']
                )
                self.graphs.append(g)
                self.labels.append(1)
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)
def darpa_split(name):
    device = "cpu"
    path = './data/' + name + '/'
    metadata = load_metadata(name)
    n_train = metadata['n_train']
    train_dataset = []
    train_labels = []
    for i in range(n_train):
        g = load_entity_level_dataset(name, 'train', i).to(device)
        train_dataset.append(g)
        train_labels.append(0)
    return (
        train_dataset,
        train_labels,
        [],
        [],
        [],
        []
    )
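# Note: for entity-level DARPA datasets (e.g. name="trace", as in eval.py), only
# training graphs are returned, all labeled benign (0); the validation/test slots
# stay empty because entity-level evaluation is handled separately in eval.py.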
def create_random_split(name):
    dataset = load_rawdata(name)
    # Random 80/10/10 split as suggested
    train_range = (0, int(0.8 * len(dataset)))
    val_range = (
        int(0.8 * len(dataset)),
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
    )
    test_range = (
        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
        len(dataset),
    )
    all_idxs = list(range(len(dataset)))
    random.shuffle(all_idxs)
    train_dataset = [
        dataset[all_idxs[i]] for i in range(train_range[0], train_range[1])
    ]
    train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])]
    val_dataset = [
        dataset[all_idxs[i]] for i in range(val_range[0], val_range[1])
    ]
    val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])]
    test_dataset = [
        dataset[all_idxs[i]] for i in range(test_range[0], test_range[1])
    ]
    test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])]
    return (
        train_dataset,
        train_labels,
        val_dataset,
        val_labels,
        test_dataset,
        test_labels,
    )
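# Example: the wget dataset above has 125 benign + 25 attack graphs (150 total),
# so this split yields 120 train / 15 validation / 15 test graphs.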
def partition_data_by_sample_size(
    args, client_number, name, uniform=True, compact=True
):
    if (name == 'wget' or name == 'streamspot'):
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = create_random_split(name)
    else:
        (
            train_dataset,
            train_labels,
            val_dataset,
            val_labels,
            test_dataset,
            test_labels,
        ) = darpa_split(name)
    num_train_samples = len(train_dataset)
    num_val_samples = len(val_dataset)
    num_test_samples = len(test_dataset)
    train_idxs = list(range(num_train_samples))
    val_idxs = list(range(num_val_samples))
    test_idxs = list(range(num_test_samples))
    random.shuffle(train_idxs)
    random.shuffle(val_idxs)
    random.shuffle(test_idxs)
    partition_dicts = [None] * client_number
    if uniform:
        clients_idxs_train = np.array_split(train_idxs, client_number)
        clients_idxs_val = np.array_split(val_idxs, client_number)
        clients_idxs_test = np.array_split(test_idxs, client_number)
    else:
        # Non-uniform (Dirichlet) partitioning; create_non_uniform_split is expected
        # to be provided elsewhere in the project, building on
        # partition_class_samples_with_dirichlet_distribution imported above.
        clients_idxs_train = create_non_uniform_split(
            args, train_idxs, client_number, True
        )
        clients_idxs_val = create_non_uniform_split(
            args, val_idxs, client_number, False
        )
        clients_idxs_test = create_non_uniform_split(
            args, test_idxs, client_number, False
        )
    labels_of_all_clients = []
    for client in range(client_number):
        client_train_idxs = clients_idxs_train[client]
        client_val_idxs = clients_idxs_val[client]
        client_test_idxs = clients_idxs_test[client]
        train_dataset_client = [
            train_dataset[idx] for idx in client_train_idxs
        ]
        train_labels_client = [train_labels[idx] for idx in client_train_idxs]
        labels_of_all_clients.append(train_labels_client)
        val_dataset_client = [val_dataset[idx] for idx in client_val_idxs]
        val_labels_client = [val_labels[idx] for idx in client_val_idxs]
        test_dataset_client = [test_dataset[idx] for idx in client_test_idxs]
        test_labels_client = [test_labels[idx] for idx in client_test_idxs]
        partition_dict = {
            "train": train_dataset_client,
            "val": val_dataset_client,
            "test": test_dataset_client,
        }
        partition_dicts[client] = partition_dict
    global_data_dict = {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset,
    }
    return global_data_dict, partition_dicts
def load_partition_data(
    args,
    client_number,
    name,
    uniform=True,
    global_test=True,
    compact=True,
    normalize_features=False,
    normalize_adj=False,
):
    global_data_dict, partition_dicts = partition_data_by_sample_size(
        args, client_number, name, uniform, compact=compact
    )
    data_local_num_dict = dict()
    train_data_local_dict = dict()
    val_data_local_dict = dict()
    test_data_local_dict = dict()
    # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE.
    train_data_global = global_data_dict["train"]
    val_data_global = global_data_dict["val"]
    test_data_global = global_data_dict["test"]
    train_data_num = len(global_data_dict["train"])
    val_data_num = len(global_data_dict["val"])
    test_data_num = len(global_data_dict["test"])
    for client in range(client_number):
        train_dataset_client = partition_dicts[client]["train"]
        val_dataset_client = partition_dicts[client]["val"]
        test_dataset_client = partition_dicts[client]["test"]
        data_local_num_dict[client] = len(train_dataset_client)
        train_data_local_dict[client] = train_dataset_client
        val_data_local_dict[client] = val_dataset_client
        test_data_local_dict[client] = (
            test_data_global
            if global_test
            else test_dataset_client
        )
        logging.info(
            "Client idx = {}, local sample number = {}".format(
                client, len(train_dataset_client)
            )
        )
    return (
        train_data_num,
        val_data_num,
        test_data_num,
        train_data_global,
        val_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        val_data_local_dict,
        test_data_local_dict,
    )
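# Usage sketch (mirroring the call in main.py):
#   load_partition_data(None, 4, "wget")
# splits the wget graphs uniformly across 4 clients and returns both the global
# splits and the per-client dictionaries.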
def load_batch_level_dataset_main(name):
    # get_data is expected to be provided elsewhere in the project and to return
    # a batch-level dataset of (graph, label) pairs.
    dataset = get_data(name)
    graph, _ = dataset[0]
    node_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
    edge_feature_dim = 0
    for g, _ in dataset:
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_dataset = [i for i in range(len(dataset))]
    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
    return {'dataset': dataset,
            'train_index': train_dataset,
            'full_index': full_dataset,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}
class GraphDataset(dgl.data.DGLDataset):
    def process(self):
        # Required by DGLDataset (same pattern as WgetDataset above); the
        # (graph, label) list is supplied directly in __init__.
        pass

    def __init__(self, graph_label_list):
        super(GraphDataset, self).__init__(name="wget")
        self.graph_label_list = graph_label_list

    def __len__(self):
        return len(self.graph_label_list)

    def __getitem__(self, idx):
        graph, label = self.graph_label_list[idx]
        # The graphs are already DGLGraphs, so they can be returned directly
        return graph, label


def transform_data(data):
    dataset = GraphDataset(data[0])
    graph, _ = dataset[0]
    node_feature_dim = 0
    for g, _ in dataset:
        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
    edge_feature_dim = 0
    for g, _ in dataset:
        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
    node_feature_dim += 1
    edge_feature_dim += 1
    full_dataset = [i for i in range(len(dataset))]
    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
    return {'dataset': dataset,
            'train_index': train_dataset,
            'full_index': full_dataset,
            'n_feat': node_feature_dim,
            'e_feat': edge_feature_dim}
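# Note: transform_data mirrors load_batch_level_dataset_main, but builds its index
# dictionaries from an in-memory (graph, label) list instead of a named dataset.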
eval.py 0 → 100644
import torch
import warnings
from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
from model.autoencoder import build_model
from utils.poolers import Pooling
from utils.utils import set_random_seed
import numpy as np
from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
from utils.config import build_args
warnings.filterwarnings('ignore')
def main(main_args):
    device = "cpu"
    device = torch.device(device)
    dataset_name = "trace"
    if dataset_name in ['streamspot', 'wget']:
        main_args.num_hidden = 256
        main_args.num_layers = 4
    else:
        main_args["num_hidden"] = 64
        main_args["num_layers"] = 3
    set_random_seed(0)
    if dataset_name == 'streamspot' or dataset_name == 'wget':
        dataset = load_batch_level_dataset(dataset_name)
        n_node_feat = dataset['n_feat']
        n_edge_feat = dataset['e_feat']
        main_args.n_dim = n_node_feat
        main_args.e_dim = n_edge_feat
        model = build_model(main_args)
        model.load_state_dict(torch.load("./result/FedOpt-{}.pt".format(dataset_name), map_location=device))
        model = model.to(device)
        pooler = Pooling(main_args.pooling)
        test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], args.dataset, main_args.n_dim,
                                                    main_args.e_dim)
    else:
        metadata = load_metadata(dataset_name)
        main_args["n_dim"] = metadata['node_feature_dim']
        main_args["e_dim"] = metadata['edge_feature_dim']
        model = build_model(main_args)
        model.load_state_dict(torch.load("./result/checkpoint-{}.pt".format(dataset_name), map_location=device))
        model = model.to(device)
        model.eval()
        malicious, _ = metadata['malicious']
        n_train = metadata['n_train']
        n_test = metadata['n_test']
        with torch.no_grad():
            x_train = []
            for i in range(n_train):
                g = load_entity_level_dataset(dataset_name, 'train', i).to(device)
                x_train.append(model.embed(g).cpu().detach().numpy())
                del g
            x_train = np.concatenate(x_train, axis=0)
            skip_benign = 0
            x_test = []
            for i in range(n_test):
                g = load_entity_level_dataset(dataset_name, 'test', i).to(device)
                # Exclude training samples from the test set
                if i != n_test - 1:
                    skip_benign += g.number_of_nodes()
                x_test.append(model.embed(g).cpu().detach().numpy())
                del g
            x_test = np.concatenate(x_test, axis=0)
            n = x_test.shape[0]
            y_test = np.zeros(n)
            y_test[malicious] = 1.0
            malicious_dict = {}
            for i, m in enumerate(malicious):
                malicious_dict[m] = i
            # Exclude training samples from the test set
            test_idx = []
            for i in range(x_test.shape[0]):
                if i >= skip_benign or y_test[i] == 1.0:
                    test_idx.append(i)
            result_x_test = x_test[test_idx]
            result_y_test = y_test[test_idx]
            del x_test, y_test
            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(dataset_name, x_train, result_x_test,
                                                                       result_y_test)
    print(f"#Test_AUC: {test_auc:.4f}±{test_std:.4f}")
    return


if __name__ == '__main__':
    args = build_args()
    main(args)
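# Usage sketch: with a trained checkpoint in place (e.g. ./result/checkpoint-trace.pt
# produced by federated training), run:
#   python eval.py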
fedml_config.yaml 0 → 100644
common_args:
  training_type: "simulation"
  random_seed: 0

data_args:
  dataset: "wget"
  data_cache_dir: ~/fedgraphnn_data/
  part_file: ~/fedgraphnn_data/partition

model_args:
  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
  global_model_file_path: "./model_file_cache/global_model.pt"

environment_args:
  bootstrap: config/bootstrap.sh

train_args:
  federated_optimizer: "FedAvg"
  client_id_list:
  client_num_in_total: 4
  client_num_per_round: 4
  comm_round: 100
  lr: 0.001
  server_lr: 0.001
  wd: 0.001
  ci: 0
  server_momentum: 0.9

validation_args:
  frequency_of_the_test: 1

device_args:
  worker_num: 4
  using_gpu: false
  gpu_mapping_file: config/gpu_mapping.yaml
  gpu_mapping_key: mapping_fedgraphnn_sp

comm_args:
  backend: "MPI"
  is_mobile: 0

tracking_args:
  # When running on the MLOps platform (open.fedml.ai), the default log path is ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
  enable_wandb: false
  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
  wandb_project: fedml
  wandb_name: fedml_torch_moleculenet
main.py 0 → 100644
import logging

import fedml
from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split
from fedml import FedMLRunner
from trainer.magic_trainer import MagicTrainer
from trainer.magic_aggregator import MagicWgetAggregator
from model.autoencoder import build_model
from utils.config import build_args
from trainer.single_trainer import train_single
from utils.loaddata import load_batch_level_dataset, load_metadata


def generate_dataset(name, number):
    (
        train_data_num,
        val_data_num,
        test_data_num,
        train_data_global,
        val_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        val_data_local_dict,
        test_data_local_dict,
    ) = load_partition_data(None, number, name)
    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        len(train_data_global),
    ]
    if (name == "wget" or name == "streamspot"):
        return dataset, load_batch_level_dataset(name)
    else:
        return dataset, load_metadata(name)


if __name__ == "__main__":
    # init FedML framework
    args = fedml.init()
    # init device
    device = fedml.device.get_device(args)
    name = args.dataset
    number = args.client_num_in_total
    dataset, metadata = generate_dataset(name, number)
    main_args = build_args()
    if (name == "wget"):
        main_args["num_hidden"] = 256
        main_args["max_epoch"] = 2
        main_args["num_layers"] = 4
        n_node_feat = metadata['n_feat']
        n_edge_feat = metadata['e_feat']
        main_args["n_dim"] = n_node_feat
        main_args["e_dim"] = n_edge_feat
    elif (name == "streamspot"):
        main_args["num_hidden"] = 256
        main_args["max_epoch"] = 5
        main_args["num_layers"] = 4
        n_node_feat = metadata['n_feat']
        n_edge_feat = metadata['e_feat']
        main_args["n_dim"] = n_node_feat
        main_args["e_dim"] = n_edge_feat
    else:
        main_args["num_hidden"] = 64
        main_args["max_epoch"] = 50
        main_args["num_layers"] = 3
        main_args["n_dim"] = metadata["node_feature_dim"]
        main_args["e_dim"] = metadata["edge_feature_dim"]
    model = build_model(main_args)
    # train_single(main_args, model, data)
    trainer = MagicTrainer(model, args, name)
    aggregator = MagicWgetAggregator(model, args, name)
    fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator)
    fedml_runner.run()
    # start training
    # darpa_split("theia")
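# Launch sketch (from the README): one MPI process per worker, e.g.
#   mpirun -np 4 -hostfile mpi_host_file --oversubscribe python main.py --cf fedml_config.yaml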