diff --git a/README.md b/README.md index 18d22331f41c4693d1558938b3e7471b4468afe2..22b75bfacc1affcdc8d9e94144606c1fc1eadc6e 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,15 @@ -# FEDHE-Graph +# CONTINUUM-FEDHE-Graph -Welcome to the official repository housing the FEDHE-Graph implementation for Magic! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Magic, a comprehensive framework for federated learning research. +Welcome to the official repository housing the FEDHE-Graph implementation for our solution Continuum! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Continuum, a comprehensive framework for federated learning research. - + + - -Original project: https://github.com/FDUDSDE/MAGIC - ## Environment Setup -The command are used in an environnement that consist of Ubuntu 22.04 with miniconda installed +The commands are used in an environment consisting of Windows 11 with Anaconda installed -Original project: https://github.com/FDUDSDE/MAGIC First create the conda environnement for fedml with MPI support @@ -24,13 +21,13 @@ conda install -c conda-forge mpi4py openmpi pip install "fedml[MPI]" ``` -Clone the MAGIC FedML project onto your current folder +Clone the Continuum FedML project into your current folder ``` -git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML +git clone https://github.com/kamelferrahi/Continuum_FL ``` -Install the necessary packages for Magic to run +Install the necessary packages for Continuum to run ``` conda install -c conda-forge aiohttp=3.9.1 aiosignal=1.3.1 anyio=4.2.0 attrdict=2.0.1 attrs=23.2.0 blis=0.7.11 boto3=1.34.12 botocore=1.34.12 brotli=1.1.0 catalogue=2.0.10 certifi=2023.11.17 chardet=5.2.0 charset-normalizer=3.3.2 click=8.1.7 cloudpathlib=0.16.0 confection=0.1.4 contourpy=1.2.0 cycler=0.12.1 cymem=2.0.8 dgl=1.1.3 dill=0.3.7 fastapi=0.92.0 fedml=0.8.13.post2 filelock=3.13.1 fonttools=4.47.0 frozenlist=1.4.1 fsspec=2023.12.2 gensim=4.3.2 gevent=23.9.1 geventhttpclient=2.0.9 gitdb=4.0.11 GitPython=3.1.40 GPUtil=1.4.0 graphviz=0.8.4 greenlet=3.0.3 h11=0.14.0 h5py=3.10.0 httpcore=1.0.2 httpx=0.26.0 idna=3.6 Jinja2=3.1.2 jmespath=1.0.1 joblib=1.3.2 kiwisolver=1.4.5 langcodes=3.3.0 MarkupSafe=2.1.3 matplotlib=3.8.2 mpi4py=3.1.3 mpmath=1.3.0 multidict=6.0.4 multiprocess=0.70.15 murmurhash=1.0.10 networkx=2.8.8 ntplib=0.4.0 numpy=1.26.3 nvidia-cublas-cu12=12.1.3.1 nvidia-cuda-cupti-cu12=12.1.105 nvidia-cuda-nvrtc-cu12=12.1.105 nvidia-cuda-runtime-cu12=12.1.105 nvidia-cudnn-cu12=8.9.2.26 nvidia-cufft-cu12=11.0.2.54 nvidia-curand-cu12=10.3.2.106 nvidia-cusolver-cu12=11.4.5.107 nvidia-cusparse-cu12=12.1.0.106 nvidia-nccl-cu12=2.18.1 nvidia-nvjitlink-cu12=12.3.101 nvidia-nvtx-cu12=12.1.105 onnx=1.15.0 packaging=23.2 paho-mqtt=1.6.1 pandas=2.1.4 pathtools=0.1.2 pillow=10.2.0 preshed=3.0.9 prettytable=3.9.0 promise=2.3 protobuf=3.20.3 psutil=5.9.7 py-machineid=0.4.6 pydantic=1.10.13 pyparsing=3.1.1 python-dateutil=2.8.2 python-rapidjson=1.14 pytz=2023.3.post1 PyYAML=6.0.1 redis=5.0.1 requests=2.31.0 s3transfer=0.10.0 scikit-learn=1.3.2 scipy=1.11.4 sentry-sdk=1.39.1 setproctitle=1.3.3 shortuuid=1.0.11 six=1.16.0 smart-open=6.3.0 smmap=5.0.1 sniffio=1.3.0 spacy=3.7.2 spacy-legacy=3.0.12 spacy-loggers=1.0.5 SQLAlchemy=2.0.25 srsly=2.4.8 starlette=0.25.0 sympy=1.12 thinc=8.2.2 threadpoolctl=3.2.0 torch=2.1.2 
torch-cluster=1.6.3 torch-scatter=2.1.2 torch-sparse=0.6.18 torch-spline-conv=1.2.2 torch_geometric=2.4.0 torchvision=0.16.2 tqdm=4.66.1 triton=2.1.0 tritonclient=2.41.0 typer=0.9.0 typing_extensions=4.9.0 tzdata=2023.4 tzlocal=5.2 urllib3=2.0.7 uvicorn=0.25.0 wandb=0.13.2 wasabi=1.1.2 wcwidth=0.2.12 weasel=0.3.4 websocket-client=1.7.0 wget=3.2 yarl=1.9.4 zope.event=5.0 zope.interface=6.1 @@ -57,7 +54,7 @@ train_args: The algorithm tested are `FedAvg`, `FedProx` and `FedOpt` ## Datasets -The experiments utilize datasets similar to those in the original Magic project. To change datasets, edit the `fedml_config.yaml` file: +The experiments utilize datasets similar to those in the original Continuum project. To change datasets, edit the `fedml_config.yaml` file: ``` data_args: dataset: "wget" diff --git a/a.yaml b/a.yaml deleted file mode 100644 index 9f8ce6e7fa889de842853b12c3a723d56ab60899..0000000000000000000000000000000000000000 --- a/a.yaml +++ /dev/null @@ -1,83 +0,0 @@ -common_args: - training_type: "cross_silo" - scenario: "horizontal" - using_mlops: false - config_version: release - name: "exp" - project: "runs/train" - exist_ok: false - random_seed: 0 - - common_args: - training_type: "simulation" - random_seed: 0 -comm_args: - backend: "MPI" - is_mobile: 0 - - -data_args: - dataset: "clintox" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition - partition_method: "hetero" - partition_alpha: 0.5 - -model_args: - model: "graphsage" - hidden_size: 32 - node_embedding_dim: 32 - graph_embedding_dim: 64 - readout_hidden_dim: 64 - alpha: 0.2 - num_heads: 2 - dropout: 0.3 - normalize_features: False - normalize_adjacency: False - sparse_adjacency: False - model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically - global_model_file_path: "./model_file_cache/global_model.pt" - -environment_args: - bootstrap: config/bootstrap.sh - -train_args: - federated_optimizer: "FedAvg" - client_id_list: - client_num_in_total: 3 - client_num_per_round: 3 - comm_round: 100 - epochs: 5 - batch_size: 64 - client_optimizer: sgd - learning_rate: 0.03 - weight_decay: 0.001 - metric: "prc-auc" - server_optimizer: sgd - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 - -validation_args: - frequency_of_the_test: 1 - -device_args: - worker_num: 3 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp - -comm_args: - backend: "MQTT_S3" - mqtt_config_path: config/mqtt_config.yaml - s3_config_path: config/s3_config.yaml - - -tracking_args: - # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ - enable_wandb: false - wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 - wandb_project: fedml - wandb_name: fedml_torch_moleculenet diff --git a/a.yml b/a.yml deleted file mode 100644 index adfbd8733fcee48a46e87a24cc48cbb1a0241f87..0000000000000000000000000000000000000000 --- a/a.yml +++ /dev/null @@ -1,68 +0,0 @@ -common_args: - training_type: "simulation" - random_seed: 0 - -data_args: - dataset: "clintox" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition - partition_method: "hetero" - partition_alpha: 0.5 - -model_args: - model: "graphsage" - hidden_size: 32 - node_embedding_dim: 32 - graph_embedding_dim: 64 - readout_hidden_dim: 64 - alpha: 0.2 - num_heads: 2 - dropout: 0.3 - normalize_features: False - normalize_adjacency: False - sparse_adjacency: False - 
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically - global_model_file_path: "./model_file_cache/global_model.pt" - -environment_args: - bootstrap: config/bootstrap.sh - -train_args: - federated_optimizer: "FedAvg" - client_id_list: - client_num_in_total: 3 - client_num_per_round: 3 - comm_round: 100 - epochs: 5 - batch_size: 64 - client_optimizer: sgd - learning_rate: 0.03 - weight_decay: 0.001 - metric: "prc-auc" - server_optimizer: sgd - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 - -validation_args: - frequency_of_the_test: 1 - -device_args: - worker_num: 2 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp - -comm_args: - backend: "MPI" - is_mobile: 0 - - -tracking_args: - # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ - enable_wandb: false - wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 - wandb_project: fedml - wandb_name: fedml_torch_moleculenet diff --git a/assets/.gitkeep b/assets/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/assets/archiFedHe.png b/assets/archiFedHe.png deleted file mode 100644 index 8495f1456bb9251d2022b52c605efaddbebd447b..0000000000000000000000000000000000000000 Binary files a/assets/archiFedHe.png and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-SC2.pt b/checkpoints - Copie/checkpoint-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..963a001f792613ee2fb2f5f22cb25b599a8dedbb Binary files /dev/null and b/checkpoints - Copie/checkpoint-SC2.pt differ diff --git a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..20073e1503d9f57b75c278bdb89c757d3aaccb30 Binary files /dev/null and b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt differ diff --git a/checkpoints - Copie/checkpoint-cadets-e3.pt b/checkpoints - Copie/checkpoint-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..ea7ee25689066513d46aea55e07f71df859ec8c7 Binary files /dev/null and b/checkpoints - Copie/checkpoint-cadets-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-clearscope-e3.pt b/checkpoints - Copie/checkpoint-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..5edaf9d73bdde92067b7e2f6f0129a4673feb945 Binary files /dev/null and b/checkpoints - Copie/checkpoint-clearscope-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-streamspot.pt b/checkpoints - Copie/checkpoint-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..ec1ed84e3dc16589af58039e87a34f21cc579cd5 Binary files /dev/null and b/checkpoints - Copie/checkpoint-streamspot.pt differ diff --git a/checkpoints - Copie/checkpoint-theia-e3.pt b/checkpoints - Copie/checkpoint-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..1513228c205f790835edb723dffe967eec960732 Binary files /dev/null and b/checkpoints - Copie/checkpoint-theia-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-trace-e3.pt b/checkpoints - Copie/checkpoint-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..755598d71f01bee477182c9fe454b4646ac7ffcd Binary files /dev/null and b/checkpoints - Copie/checkpoint-trace-e3.pt differ diff --git 
a/checkpoints - Copie/checkpoint-wget-long.pt b/checkpoints - Copie/checkpoint-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..f5c8f96c5be068033746e3b9271b983c438bee54 Binary files /dev/null and b/checkpoints - Copie/checkpoint-wget-long.pt differ diff --git a/checkpoints - Copie/checkpoint-wget.pt b/checkpoints - Copie/checkpoint-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..1d5f6c0071b7c0d0e136aa6112ade76c083984fc Binary files /dev/null and b/checkpoints - Copie/checkpoint-wget.pt differ diff --git a/checkpoints/checkpoint-SC2.pt b/checkpoints/checkpoint-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..aec48a89437ad5ef4c3199b0547a3efc9365115f Binary files /dev/null and b/checkpoints/checkpoint-SC2.pt differ diff --git a/checkpoints/checkpoint-Unicorn-Cadets.pt b/checkpoints/checkpoint-Unicorn-Cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..90d5dcfbb036da54972648625073287bad1b6986 Binary files /dev/null and b/checkpoints/checkpoint-Unicorn-Cadets.pt differ diff --git a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt new file mode 100644 index 0000000000000000000000000000000000000000..5f62ae9abed4280d950ebf431a30b2f883b386a2 Binary files /dev/null and b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt differ diff --git a/checkpoints/checkpoint-cadets-e3.pt b/checkpoints/checkpoint-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..157f6b404a8228c9da2483564ab892e0fd4f8a88 Binary files /dev/null and b/checkpoints/checkpoint-cadets-e3.pt differ diff --git a/checkpoints/checkpoint-clearscope-e3.pt b/checkpoints/checkpoint-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..0ebdb975ec2e7455daed67ab002f1b39a9984c1b Binary files /dev/null and b/checkpoints/checkpoint-clearscope-e3.pt differ diff --git a/checkpoints/checkpoint-streamspot.pt b/checkpoints/checkpoint-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..0065a7b8b52665153af952617ab45cde61fe14eb Binary files /dev/null and b/checkpoints/checkpoint-streamspot.pt differ diff --git a/checkpoints/checkpoint-theia-e3.pt b/checkpoints/checkpoint-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..1f0fcc66e18b3e6d534f99e090e5387807451a64 Binary files /dev/null and b/checkpoints/checkpoint-theia-e3.pt differ diff --git a/checkpoints/checkpoint-trace-e3.pt b/checkpoints/checkpoint-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..f050065e56c009605562ca6e980652b6c2a33938 Binary files /dev/null and b/checkpoints/checkpoint-trace-e3.pt differ diff --git a/checkpoints/checkpoint-wget-long.pt b/checkpoints/checkpoint-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..040897b99810c72f6bb6d2314aab468b3972de3e Binary files /dev/null and b/checkpoints/checkpoint-wget-long.pt differ diff --git a/checkpoints/checkpoint-wget.pt b/checkpoints/checkpoint-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..8404279998dc213fba6b6f626d55c37ccb7c8ca9 Binary files /dev/null and b/checkpoints/checkpoint-wget.pt differ diff --git a/distance_save_cadets.pkl b/distance_save_cadets.pkl deleted file mode 100644 index 1f2f428d6e2e6ffb8afbfd0d0669c46cab8718c6..0000000000000000000000000000000000000000 Binary files a/distance_save_cadets.pkl and /dev/null differ diff --git a/eval_result/.gitkeep 
b/eval_result/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/eval_result/distance_save_cadets-e3 - Copie.pkl b/eval_result/distance_save_cadets-e3 - Copie.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7f737190e44abb34ebea3e6763c64b9f1ae3a89a Binary files /dev/null and b/eval_result/distance_save_cadets-e3 - Copie.pkl differ diff --git a/eval_result/distance_save_cadets-e3.pkl b/eval_result/distance_save_cadets-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..10c18f8c50bfce89a1afd0811cc24ce8e0595786 Binary files /dev/null and b/eval_result/distance_save_cadets-e3.pkl differ diff --git a/eval_result/distance_save_cadets.pkl b/eval_result/distance_save_cadets.pkl deleted file mode 100644 index 2bb4de8e00a64a22e53b1a5187821e6e53dfffcc..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets.pkl and /dev/null differ diff --git a/eval_result/distance_save_theia-e3.pkl b/eval_result/distance_save_theia-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c5d31bab1acc6e943c746a154eec3a9370ccb644 Binary files /dev/null and b/eval_result/distance_save_theia-e3.pkl differ diff --git a/eval_result/distance_save_trace-e3 - Copie.pkl b/eval_result/distance_save_trace-e3 - Copie.pkl new file mode 100644 index 0000000000000000000000000000000000000000..709b4f83bc0d8cd4d8be175fed1c54d366c9de16 Binary files /dev/null and b/eval_result/distance_save_trace-e3 - Copie.pkl differ diff --git a/eval_result/distance_save_trace-e3.pkl b/eval_result/distance_save_trace-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..709b4f83bc0d8cd4d8be175fed1c54d366c9de16 Binary files /dev/null and b/eval_result/distance_save_trace-e3.pkl differ diff --git a/fedml_config.yaml b/fedml_config.yaml index 66714a20c7a2b471676dac8be0801b3ae132e5cf..40cd616a925f3514dda97dd8ba1b20aaead81744 100644 --- a/fedml_config.yaml +++ b/fedml_config.yaml @@ -1,49 +1,53 @@ common_args: - training_type: "simulation" + training_type: "cross_silo" + scenario: "horizontal" + using_mlops: false + config_version: release + name: "exp" + project: "runs/train" + exist_ok: false random_seed: 0 data_args: - dataset: "wget" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition + dataset: "trace-e3" model_args: model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically global_model_file_path: "./model_file_cache/global_model.pt" -environment_args: - bootstrap: config/bootstrap.sh train_args: federated_optimizer: "FedAvg" client_id_list: - client_num_in_total: 4 - client_num_per_round: 4 - comm_round: 100 - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 + client_num_in_total: 2 + client_num_per_round: 2 + comm_round: 1 + snapshot: 1 validation_args: frequency_of_the_test: 1 device_args: - worker_num: 4 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp + worker_num: 2 + using_gpu: true + gpu_mapping_file: gpu_mapping.yaml + gpu_mapping_key: mapping_config comm_args: - backend: "MPI" - is_mobile: 0 - + backend: "MQTT_S3" + mqtt_config_path: config/mqtt_config.yaml tracking_args: # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ enable_wandb: false wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 wandb_project: 
fedml - wandb_name: fedml_torch_moleculenet + wandb_name: fedml_torch + +# fhe_args: +# # enable_fhe: true + # scheme: ckks +# batch_size: 8192 +# scaling_factor: 52 +# file_loc: "resources/cryptoparams/" \ No newline at end of file diff --git a/gpu_mapping.yaml b/gpu_mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90b13bfe6abfa84b56845ac220a5c67ced781cbe --- /dev/null +++ b/gpu_mapping.yaml @@ -0,0 +1,2 @@ +mapping_config: + host1: [3] \ No newline at end of file diff --git a/main.py b/main.py index a1a716e17f69f73d33fa62ef5b3da1e85b90b715..7d7430d7fb88d4259d8096266633bb5dbd94bc8f 100644 --- a/main.py +++ b/main.py @@ -1,20 +1,18 @@ import logging import fedml -from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split +from utils.dataloader import load_partition_data, load_data, load_metadata, darpa_split from fedml import FedMLRunner from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from model.autoencoder import build_model +from model.model import STGNN_AutoEncoder from utils.config import build_args from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from trainer.single_trainer import train_single -from utils.loaddata import load_batch_level_dataset, load_metadata -def generate_dataset(name, number): +def generate_dataset(name, number, nsnapshot): ( train_data_num, val_data_num, @@ -26,7 +24,7 @@ def generate_dataset(name, number): train_data_local_dict, val_data_local_dict, test_data_local_dict, - ) = load_partition_data(None, number, name) + ) = load_partition_data(number, name, nsnapshot) dataset = [ train_data_num, test_data_num, @@ -38,9 +36,9 @@ def generate_dataset(name, number): len(train_data_global), ] - if (name == "wget" or name == "streamspot"): + if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - return dataset, load_batch_level_dataset(name) + return dataset, load_data(name) else: return dataset, load_metadata(name) @@ -49,39 +47,47 @@ if __name__ == "__main__": # init FedML framework args = fedml.init() # init device + device = fedml.device.get_device(args) - name = args.dataset + dataset_name = args.dataset number = args.client_num_in_total - - dataset, metadata = generate_dataset(name, number) + nsnapshot = args.snapshot + dataset, metadata = generate_dataset(dataset_name, number, nsnapshot) main_args = build_args() - if (name == "wget"): - main_args["num_hidden"] = 256 - main_args["max_epoch"] = 2 - main_args["num_layers"] = 4 + if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + + main_args.max_epoch = 6 + out_dim = 64 + if (dataset_name == 'SC2'): + gnn_layer = 3 + else: + gnn_layer = 5 n_node_feat = metadata['n_feat'] n_edge_feat = metadata['e_feat'] - main_args["n_dim"] = n_node_feat - main_args["e_dim"] = n_edge_feat - elif (name == "streamspot"): - main_args["num_hidden"] = 256 - main_args["max_epoch"] = 5 - main_args["num_layers"] = 4 - n_node_feat = metadata['n_feat'] - n_edge_feat = metadata['e_feat'] - main_args["n_dim"] = n_node_feat - main_args["e_dim"] = n_edge_feat - else: - main_args["num_hidden"] = 64 - main_args["max_epoch"] = 50 - main_args["num_layers"] = 3 - main_args["n_dim"] = metadata["node_feature_dim"] - main_args["e_dim"] = metadata["edge_feature_dim"] - - model = build_model(main_args) + use_all_hidden = True + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + 
else: + use_all_hidden = False + n_node_feat = metadata['node_feature_dim'] + n_edge_feat = metadata['edge_feature_dim'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + main_args.max_epoch = 50 + out_dim = 64 + + if (dataset_name == 'cadets-e3'): + gnn_layer = 4 + else: + gnn_layer = 3 + + + + model = STGNN_AutoEncoder(main_args.n_dim, main_args.e_dim, out_dim, out_dim, gnn_layer, 4, device, nsnapshot, 'prelu', 0.1, main_args.negative_slope, True, 'BatchNorm', main_args.pooling, alpha_l=main_args.alpha_l, use_all_hidden=use_all_hidden).to(device) # Move model to GPU #train_single(main_args, model, data) - trainer = MagicTrainer(model, args, name) - aggregator = MagicWgetAggregator(model, args, name) + trainer = MagicTrainer(model, args, dataset_name) + aggregator = MagicWgetAggregator(model, args, dataset_name) fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator) fedml_runner.run() # start training diff --git a/model/__pycache__/autoencoder.cpython-311.pyc b/model/__pycache__/autoencoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3abc72625245a57fa06470e3659c594a0c033046 Binary files /dev/null and b/model/__pycache__/autoencoder.cpython-311.pyc differ diff --git a/model/__pycache__/eval.cpython-310.pyc b/model/__pycache__/eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc58b1214be9d335558c42907cf1feb18b81969 Binary files /dev/null and b/model/__pycache__/eval.cpython-310.pyc differ diff --git a/model/__pycache__/eval.cpython-311.pyc b/model/__pycache__/eval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faade8397a5b3c546b35f34aef088d8992486510 Binary files /dev/null and b/model/__pycache__/eval.cpython-311.pyc differ diff --git a/model/__pycache__/gat.cpython-310.pyc b/model/__pycache__/gat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54fa4abdd4680faa41fbf0e19697fab2f000402b Binary files /dev/null and b/model/__pycache__/gat.cpython-310.pyc differ diff --git a/model/__pycache__/gat.cpython-311.pyc b/model/__pycache__/gat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94e6046f4f22c6a2f271c2ca3bded5d8f2417cfa Binary files /dev/null and b/model/__pycache__/gat.cpython-311.pyc differ diff --git a/model/__pycache__/loss_func.cpython-310.pyc b/model/__pycache__/loss_func.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ac15810418709637606b15c5c43b4b96aec691 Binary files /dev/null and b/model/__pycache__/loss_func.cpython-310.pyc differ diff --git a/model/__pycache__/loss_func.cpython-311.pyc b/model/__pycache__/loss_func.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b3c9d58012a8600cda16f7d39ee5c5891fa862 Binary files /dev/null and b/model/__pycache__/loss_func.cpython-311.pyc differ diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3984a3444fb13f56ea0c2629e2b0883298bf8f4d Binary files /dev/null and b/model/__pycache__/model.cpython-310.pyc differ diff --git a/model/__pycache__/model.cpython-311.pyc b/model/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a27f13ee6839a64257d3125b777544a866f329 Binary files /dev/null and b/model/__pycache__/model.cpython-311.pyc differ diff --git a/model/__pycache__/rnn.cpython-310.pyc b/model/__pycache__/rnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d02d9b442912cda476cacbc18673eebe9e000da7 Binary files /dev/null and b/model/__pycache__/rnn.cpython-310.pyc differ diff --git a/model/__pycache__/rnn.cpython-311.pyc b/model/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d71283ae5cbfc01ec9eb9f1448d8e8b292d1f08 Binary files /dev/null and b/model/__pycache__/rnn.cpython-311.pyc differ diff --git a/model/__pycache__/test.cpython-310.pyc b/model/__pycache__/test.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d85be9aa8853ef325090a1b1cd0f366d3224792 Binary files /dev/null and b/model/__pycache__/test.cpython-310.pyc differ diff --git a/model/__pycache__/test.cpython-311.pyc b/model/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7beba377fb059135e3135b292a910477636c7c Binary files /dev/null and b/model/__pycache__/test.cpython-311.pyc differ diff --git a/model/__pycache__/train.cpython-311.pyc b/model/__pycache__/train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b788a354802ea312364f91691db5b00e502922a6 Binary files /dev/null and b/model/__pycache__/train.cpython-311.pyc differ diff --git a/model/__pycache__/train_entity.cpython-310.pyc b/model/__pycache__/train_entity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbe4cebe4c9d3a09cdc4a8b78f976f98a23baef1 Binary files /dev/null and b/model/__pycache__/train_entity.cpython-310.pyc differ diff --git a/model/__pycache__/train_entity.cpython-311.pyc b/model/__pycache__/train_entity.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1b64f3a76e749403864eb4a4a10154a8003a84 Binary files /dev/null and b/model/__pycache__/train_entity.cpython-311.pyc differ diff --git a/model/__pycache__/train_graph.cpython-310.pyc b/model/__pycache__/train_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf931f350c1de50fb725f95c7ebc73f1c6facc66 Binary files /dev/null and b/model/__pycache__/train_graph.cpython-310.pyc differ diff --git a/model/__pycache__/train_graph.cpython-311.pyc b/model/__pycache__/train_graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98093beaf11c1ffeb71b9abc556e33503934038 Binary files /dev/null and b/model/__pycache__/train_graph.cpython-311.pyc differ diff --git a/model/eval.py b/model/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7446db31e19a2e144eb061e0526cb44e23798935 --- /dev/null +++ b/model/eval.py @@ -0,0 +1,270 @@ +import os +import random +import time +import pickle as pkl +import torch +import numpy as np +from sklearn.metrics import roc_auc_score, precision_recall_curve +from sklearn.neighbors import NearestNeighbors +from utils.utils import set_random_seed +from utils.dataloader import load_graph, load_data +import pickle as 
pkl + +def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): + print('Start Evaluation') + + model.eval() + x_list = [] + y_list = [] + data = load_data(dataset) + full = data['full_index'] + labels = data['labels'] + with torch.no_grad(): + for i in full: + #break + g = load_graph(i, dataset, device) + label = labels[i] + out = model.embed(g) + if dataset != 'wget': + out = pooler(g[-1], out).cpu().numpy() + else: + out = pooler(g[-1], out, [1]).cpu().numpy() + y_list.append(label) + x_list.append(out) + + #pkl.dump(x_list,open('xlist.pkl','wb') ) + #pkl.dump(y_list,open('ylist.pkl','wb') ) + #x_list = pkl.load(open('xlist.pkl','rb')) + #y_list = pkl.load(open('ylist.pkl','rb')) + x = np.concatenate(x_list, axis=0) + y = np.array(y_list) + if 'knn' in method: + test_auc, test_std = evaluate_batch_level_using_knn(-1, dataset, x, y) + else: + raise NotImplementedError + return test_auc, test_std + + +def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): + x, y = embeddings, labels + if dataset == 'streamspot': + train_count = 400 + elif (dataset == 'Unicorn-Cadets' or dataset == 'wget-long'): + train_count = 70 + elif (dataset == 'wget' or dataset == 'SC2'): + train_count = 100 + else: + train_count = 30 + + if (dataset =='SC2'): + n_neighbors = min(int(train_count * 0.02), 10) + else: + n_neighbors = 100 + + benign_idx = np.where(y == 0)[0] + attack_idx = np.where(y == 1)[0] + if repeat != -1: + prec_list = [] + rec_list = [] + f1_list = [] + tp_list = [] + fp_list = [] + tn_list = [] + fn_list = [] + auc_list = [] + for s in range(repeat): + set_random_seed(s) + np.random.shuffle(benign_idx) + np.random.shuffle(attack_idx) + x_train = x[benign_idx[:train_count]] + x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + max_f1_idx = np.argmax(f1) + best_thres = threshold[max_f1_idx] + prec_list.append(prec[max_f1_idx]) + rec_list.append(rec[max_f1_idx]) + f1_list.append(f1[max_f1_idx]) + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + tp_list.append(tp) + fp_list.append(fp) + fn_list.append(fn) + tn_list.append(tn) + auc_list.append(auc) + + print('AUC: {}+{}'.format(np.mean(auc_list), np.std(auc_list))) + print('F1: {}+{}'.format(np.mean(f1_list), np.std(f1_list))) + print('PRECISION: {}+{}'.format(np.mean(prec_list), np.std(prec_list))) + print('RECALL: {}+{}'.format(np.mean(rec_list), np.std(rec_list))) + print('TN: {}+{}'.format(np.mean(tn_list), np.std(tn_list))) + print('FN: {}+{}'.format(np.mean(fn_list), np.std(fn_list))) + 
print('TP: {}+{}'.format(np.mean(tp_list), np.std(tp_list))) + print('FP: {}+{}'.format(np.mean(fp_list), np.std(fp_list))) + return np.mean(auc_list), np.std(auc_list) + else: + set_random_seed(2022) + np.random.shuffle(benign_idx) + np.random.shuffle(attack_idx) + x_train = x[benign_idx[:train_count]] + #x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + f1_max = 0 + for n_neighbors in range(1, train_count): + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + #mean_distance = 0.1 + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + + if (f1[best_idx]> f1_max): + f1_max = f1[best_idx] + auc_max = auc + prec_max = prec[best_idx] + rec_max = rec[best_idx] + tn_max = tn + fn_max = fn + tp_max = tp + fp_max = fp + best_n = n_neighbors + + print('AUC: {}'.format(auc_max)) + print('F1: {}'.format(f1_max)) + print('PRECISION: {}'.format(prec_max)) + print('RECALL: {}'.format(rec_max)) + print('TN: {}'.format(tn_max)) + print('FN: {}'.format(fn_max)) + print('TP: {}'.format(tp_max)) + print('FP: {}'.format(fp_max)) + print(best_n) + return auc, 0.0 + + +def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + + if dataset == 'cadets-e3': + n_neighbors = 200 + else: + n_neighbors = 10 + + nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1) + nbrs.fit(x_train) + + save_dict_path = './eval_result/distance_save_{}.pkl'.format(dataset) + if not os.path.exists(save_dict_path): + idx = list(range(x_train.shape[0])) + random.shuffle(idx) + distances, _ = nbrs.kneighbors(x_train[idx][:min(50000, x_train.shape[0])], n_neighbors=n_neighbors) + del x_train + mean_distance = distances.mean() + del distances + distances, _ = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + save_dict = [mean_distance, distances.mean(axis=1)] + distances = distances.mean(axis=1) + with open(save_dict_path, 'wb') as f: + pkl.dump(save_dict, f) + else: + with open(save_dict_path, 'rb') as f: + mean_distance, distances = pkl.load(f) + score = distances / mean_distance + del distances + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + + for i in range(len(f1)): + # To repeat peak performance + if dataset == 'trace' and 
rec[i] < 0.99979: + best_idx = i - 1 + break + if dataset == 'theia' and rec[i] < 0.99996: + best_idx = i - 1 + break + if dataset == 'cadets-e3' and rec[i] < 0.9976: + best_idx = i - 1 + break + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + print('AUC: {}'.format(auc)) + print('F1: {}'.format(f1[best_idx])) + print('PRECISION: {}'.format(prec[best_idx])) + print('RECALL: {}'.format(rec[best_idx])) + print('TN: {}'.format(tn)) + print('FN: {}'.format(fn)) + print('TP: {}'.format(tp)) + print('FP: {}'.format(fp)) + return auc, 0.0, None, None \ No newline at end of file diff --git a/model/gat.py b/model/gat.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2d2daeca4468f356274049f7f4454496a8ac06 --- /dev/null +++ b/model/gat.py @@ -0,0 +1,232 @@ +import torch +import torch.nn as nn +from dgl.ops import edge_softmax +import dgl.function as fn +from dgl.utils import expand_as_pair +from utils.utils import create_activation + + +class GAT(nn.Module): + def __init__(self, + n_dim, + e_dim, + hidden_dim, + out_dim, + n_layers, + n_heads, + n_heads_out, + activation, + feat_drop, + attn_drop, + negative_slope, + residual, + norm, + concat_out=False, + encoding=False + ): + super(GAT, self).__init__() + self.out_dim = out_dim + self.n_heads = n_heads + self.n_layers = n_layers + self.gats = nn.ModuleList() + self.concat_out = concat_out + + last_activation = create_activation(activation) if encoding else None + last_residual = (encoding and residual) + last_norm = norm if encoding else None + if self.n_layers == 1: + self.gats.append(GATConv( + n_dim, e_dim, out_dim, n_heads_out, feat_drop, attn_drop, negative_slope, + last_residual, norm=last_norm, concat_out=self.concat_out + )) + else: + self.gats.append(GATConv( + n_dim, e_dim, hidden_dim, n_heads, feat_drop, attn_drop, negative_slope, + residual, create_activation(activation), + norm=norm, concat_out=self.concat_out + )) + for _ in range(1, self.n_layers - 1): + self.gats.append(GATConv( + hidden_dim * self.n_heads, e_dim, hidden_dim, n_heads, + feat_drop, attn_drop, negative_slope, + residual, create_activation(activation), + norm=norm, concat_out=self.concat_out + )) + self.gats.append(GATConv( + hidden_dim * self.n_heads, e_dim, out_dim, n_heads_out, + feat_drop, attn_drop, negative_slope, + last_residual, last_activation, norm=last_norm, concat_out=self.concat_out + )) + self.head = nn.Identity() + + def forward(self, g, input_feature, return_hidden=False): + h = input_feature + hidden_list = [] + for layer in range(self.n_layers): + h = self.gats[layer](g, h) + hidden_list.append(h) + if return_hidden: + return self.head(h), hidden_list + else: + return self.head(h) + + def reset_classifier(self, num_classes): + self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) + + +class GATConv(nn.Module): + def __init__(self, + in_dim, + e_dim, + out_dim, + n_heads, + feat_drop=0.0, + attn_drop=0.0, + negative_slope=0.2, + residual=False, + activation=None, + allow_zero_in_degree=False, + bias=True, + norm=None, + concat_out=True): + super(GATConv, self).__init__() + self.n_heads = n_heads + self.src_feat, self.dst_feat = expand_as_pair(in_dim) + self.edge_feat = e_dim + self.out_feat = out_dim + self.allow_zero_in_degree 
= allow_zero_in_degree + self.concat_out = concat_out + + if isinstance(in_dim, tuple): + self.fc_node_embedding = nn.Linear( + self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc_src = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc_dst = nn.Linear(self.dst_feat, self.out_feat * self.n_heads, bias=False) + else: + self.fc_node_embedding = nn.Linear( + self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False) + self.edge_fc = nn.Linear(self.edge_feat, self.out_feat * self.n_heads, bias=False) + self.attn_h = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.attn_e = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.attn_t = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if bias: + self.bias = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + else: + self.register_buffer('bias', None) + if residual: + if self.dst_feat != self.n_heads * self.out_feat: + self.res_fc = nn.Linear( + self.dst_feat, self.n_heads * self.out_feat, bias=False) + else: + self.res_fc = nn.Identity() + else: + self.register_buffer('res_fc', None) + self.reset_parameters() + self.activation = activation + self.norm = norm + if norm is not None: + self.norm = norm(self.n_heads * self.out_feat) + + def reset_parameters(self): + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.edge_fc.weight, gain=gain) + if hasattr(self, 'fc'): + nn.init.xavier_normal_(self.fc.weight, gain=gain) + else: + nn.init.xavier_normal_(self.fc_src.weight, gain=gain) + nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + nn.init.xavier_normal_(self.attn_h, gain=gain) + nn.init.xavier_normal_(self.attn_e, gain=gain) + nn.init.xavier_normal_(self.attn_t, gain=gain) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + + def set_allow_zero_in_degree(self, set_value): + self.allow_zero_in_degree = set_value + + def forward(self, graph, feat, get_attention=False): + edge_feature = graph.edata['attr'] + with graph.local_scope(): + if isinstance(feat, tuple): + src_prefix_shape = feat[0].shape[:-1] + dst_prefix_shape = feat[1].shape[:-1] + h_src = self.feat_drop(feat[0]) + h_dst = self.feat_drop(feat[1]) + if not hasattr(self, 'fc_src'): + feat_src = self.fc(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + feat_dst = self.fc(h_dst).view( + *dst_prefix_shape, self.n_heads, self.out_feat) + else: + feat_src = self.fc_src(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + feat_dst = self.fc_dst(h_dst).view( + *dst_prefix_shape, self.n_heads, self.out_feat) + else: + src_prefix_shape = dst_prefix_shape = feat.shape[:-1] + h_src = h_dst = self.feat_drop(feat) + feat_src = feat_dst = self.fc(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + if graph.is_block: + feat_dst = feat_src[:graph.number_of_dst_nodes()] + h_dst = h_dst[:graph.number_of_dst_nodes()] + dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] + edge_prefix_shape = edge_feature.shape[:-1] + eh = (feat_src * self.attn_h).sum(-1).unsqueeze(-1) + et = (feat_dst * self.attn_t).sum(-1).unsqueeze(-1) + + graph.srcdata.update({'hs': feat_src, 'eh': eh}) + 
graph.dstdata.update({'et': et}) + + feat_edge = self.edge_fc(edge_feature).view( + *edge_prefix_shape, self.n_heads, self.out_feat) + ee = (feat_edge * self.attn_e).sum(-1).unsqueeze(-1) + + graph.edata.update({'ee': ee}) + graph.apply_edges(fn.u_add_e('eh', 'ee', 'ee')) + graph.apply_edges(fn.e_add_v('ee', 'et', 'e')) + """ + graph.apply_edges(fn.u_add_v('eh', 'et', 'e')) + """ + e = self.leaky_relu(graph.edata.pop('e')) + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + # message passing + + graph.update_all(fn.u_mul_e('hs', 'a', 'm'), + fn.sum('m', 'hs')) + + rst = graph.dstdata['hs'].view(-1, self.n_heads, self.out_feat) + + if self.bias is not None: + rst = rst + self.bias.view( + *((1,) * len(dst_prefix_shape)), self.n_heads, self.out_feat) + + # residual + + if self.res_fc is not None: + # Use -1 rather than self._num_heads to handle broadcasting + resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self.out_feat) + rst = rst + resval + + if self.concat_out: + rst = rst.flatten(1) + else: + rst = torch.mean(rst, dim=1) + + if self.norm is not None: + rst = self.norm(rst) + + # activation + if self.activation: + rst = self.activation(rst) + + if get_attention: + return rst, graph.edata['a'] + else: + return rst diff --git a/model/loss_func.py b/model/loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..a2eff3e79376bf6ec4157b3c56f3a6ebd50a3127 --- /dev/null +++ b/model/loss_func.py @@ -0,0 +1,9 @@ +import torch.nn.functional as F + + +def sce_loss(x, y, alpha=3): + x = F.normalize(x, p=2, dim=-1) + y = F.normalize(y, p=2, dim=-1) + loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) + loss = loss.mean() + return loss diff --git a/model/model - Copie.py b/model/model - Copie.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd76a77dea36485ce923267908cca624db6ae65 --- /dev/null +++ b/model/model - Copie.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.poolers import Pooling +from dgl.nn import EdgeGATConv, GlobalAttentionPooling +from torch.nn import GRUCell +import dgl + + + + + + +class GNN_DTDG(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, mlp_layers, number_snapshot): + super(GNN_DTDG, self).__init__() + self.encoder = GNN_RNN(n_dim, e_dim, hidden_dim, out_dim ,n_layers, n_heads, device, number_snapshot) + self.decoder = GNN_RNN(out_dim, e_dim, hidden_dim, n_dim ,n_layers, n_heads, device, number_snapshot) + self.number_snapshot = number_snapshot + self.classifier_layers = nn.ModuleList([ + ]) + + for _ in range(mlp_layers - 1): + self.classifier_layers.extend([ + nn.Linear(out_dim, out_dim).to(device), + nn.ReLU(), + ]) + self.classifier_layers.extend([ + nn.Linear(out_dim, 1).to(device), + nn.Sigmoid() + ]) + + self.pooling_gate_nn = nn.Linear(out_dim , 1) + self.pooling = GlobalAttentionPooling(self.pooling_gate_nn) + self.pooler = Pooling("mean") + self.encoder_to_decoder = nn.Linear( out_dim, out_dim, bias=False) + + def forward(self, g): + encoded = self.encoder(g) + new_g = [] + i= 0 + for G in g: + g_encoded = G.clone() + g_encoded.ndata["attr"] = self.encoder_to_decoder(encoded[i]) + new_g.append(g_encoded) + i+=1 + + + + decoded = self.decoder(new_g) + return decoded[-1] + # x = self.pooler(G, embeddings, [1])[0] + # h_g = x.clone() + # for layer in self.classifier_layers: + # x = layer(x) + + def embed(self, g): + return self.encoder(g)[-1] + + + +class GNN_RNN(nn.Module): + def __init__(self, n_dim, e_dim, 
hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot): + super(GNN_RNN, self).__init__() + self.device = device + self.gnn_layers = nn.ModuleList([EdgeGATConv(in_feats=n_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device)]) + + self.out_dim = out_dim + + for _ in range(n_layers-1): + self.gnn_layers.append( + EdgeGATConv(in_feats=out_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device) + ) + + self.rnn_layers = nn.ModuleList([]) + + for _ in range(number_snapshot): + self.rnn_layers.append( + GRUCell(out_dim, out_dim, device = device) + ) + + self.classifier_layers = nn.ModuleList([ + ]) + + + def forward(self, g): + i = 0 + H_s = [] + for G in g: + + with G.to(self.device).local_scope(): + x = G.ndata["attr"].float() + e = G.edata["attr"].float() + for layer in self.gnn_layers: + r = layer(G, x, e) + x = torch.mean(r,dim=1).to(self.device) + del r + + #if ( i == 0): + # H = self.rnn_layers[i](x, x) + # else: + # H = self.rnn_layers[i](x, H) + + H = x + H_s.append(H) + embeddings = H.clone() + i+=1 + #x = self.pooling(g[0], x)[0] + + + return H_s \ No newline at end of file diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0389eeae588c10742143ba62a97e9b8f5639eb47 --- /dev/null +++ b/model/model.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +from utils.poolers import Pooling +from .loss_func import sce_loss +from .gat import GAT +from .rnn import RNN_Cells +from utils.utils import create_norm +from functools import partial + + +class STGNN_AutoEncoder(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot, activation, feat_drop, negative_slope, residual, norm, pooling, loss_fn="sce", alpha_l=2, use_all_hidden = True): + super(STGNN_AutoEncoder, self).__init__() + + #Initialize the encoder and decoder structure + self.encoder = STGNN(n_dim, e_dim, out_dim, out_dim, n_layers, n_heads, n_heads, number_snapshot, activation, feat_drop, negative_slope, residual, norm, True, use_all_hidden, device) + self.decoder = STGNN(out_dim, e_dim, out_dim, n_dim, 1, n_heads, 1, number_snapshot, activation, feat_drop, negative_slope, residual, norm, False, False, device) + + + # Linear layer for mapping encoder output to decoder input + if (use_all_hidden): + self.encoder_to_decoder = nn.Linear(n_layers * out_dim, out_dim, bias=False) + else: + self.encoder_to_decoder = nn.Linear(out_dim, out_dim, bias=False) + + # Additional components and parameters + self.n_layers = n_layers + self.pooler = Pooling(pooling) + self.number_snapshot = number_snapshot + self.use_all_hidden = use_all_hidden + self.device = device + self.criterion = self.setup_loss_fn(loss_fn, alpha_l) + + def setup_loss_fn(self, loss_fn, alpha_l): + if loss_fn == "sce": + criterion = partial(sce_loss, alpha=alpha_l) + + elif loss_fn == "ce": + criterion = nn.CrossEntropyLoss() + elif loss_fn == "mse": + criterion = nn.MSELoss() + elif loss_fn == "mae": + criterion = nn.L1Loss() + else: + raise NotImplementedError + return criterion + + + def forward(self, g): + + # Encode input graphs + node_features = [] + new_t = [] + for G in g: + new_g = G.clone() + node_features.append(new_g.ndata['attr'].float()) + new_g.edata['attr'] = new_g.edata['attr'].float() + new_t.append(new_g) + final_embedding = self.encoder(new_t, node_features) + encoding = [] + if (self.use_all_hidden): + for i in range(len(g)): + conca = 
[final_embedding[j][i] for j in range(len(final_embedding))] + encoding.append(torch.cat(conca,dim=1)) + else: + encoding = final_embedding[0] + + node_features = [] + for encoded in encoding: + encoded = self.encoder_to_decoder(encoded) + node_features.append(encoded) + + reconstructed = self.decoder(new_t, node_features) + recon = reconstructed[0][-1] + x_init = g[0].ndata['attr'].float() + loss = self.criterion(recon, x_init) + + return loss + + def embed(self, g): + node_features= [] + for G in g: + node_features.append(G.ndata['attr'].float()) + + return self.encoder.embed(g, node_features) + + +class STGNN(nn.Module): + + def __init__(self, input_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, n_heads_out, n_snapshot, activation, feat_drop, negative_slope, residual, norm, encoding, use_all_hidden, device): + super(STGNN, self).__init__() + + if encoding: + out = out_dim // n_heads + hidden = out_dim // n_heads + else: + hidden = hidden_dim + out = out_dim + + self.gnn = GAT( + n_dim=input_dim, + e_dim=e_dim, + hidden_dim=hidden, + out_dim=out, + n_layers=n_layers, + n_heads=n_heads, + n_heads_out=n_heads_out, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=encoding, + ) + self.rnn = RNN_Cells(out_dim, out_dim, n_snapshot, device) + self.use_all_hidden = use_all_hidden + + def forward(self, G, node_features): + + embeddings = [] + for i in range(len(G)): + g = G[i] + if (self.use_all_hidden): + node_embedding, all_hidden = self.gnn(g, node_features[i], return_hidden = self.use_all_hidden) + embeddings.append(all_hidden) + n_iter = len(all_hidden) + + else: + embeddings.append(self.gnn(g, node_features[i], return_hidden = self.use_all_hidden)) + n_iter = 1 + + result = [] + for j in range(n_iter): + encoding = [] + + for embedding in embeddings : + if (self.use_all_hidden): + encoding.append(embedding[j]) + else: + encoding.append(embedding) + + result.append(self.rnn(encoding)) + + return result + + + def embed(self, G, node_features): + embeddings = [] + for i in range(len(G)): + g = G[i].clone() + g.edata['attr'] = g.edata['attr'].float() + embedding = self.gnn(g, node_features[i], return_hidden = False) + embeddings.append(embedding) + + return self.rnn(embeddings)[-1] \ No newline at end of file diff --git a/model/rnn.py b/model/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..663634fbdfbba966894f74241faed35876657f5c --- /dev/null +++ b/model/rnn.py @@ -0,0 +1,23 @@ +from torch.nn import GRUCell +import torch.nn as nn + + +class RNN_Cells(nn.Module): + def __init__(self, input_dim, hidden_dim, n_cells, device) : + super(RNN_Cells, self).__init__() + self.cells = nn.ModuleList() + + for i in range(n_cells): + self.cells.append(GRUCell(input_dim, hidden_dim, device=device)) + + + def forward(self, inputs): + + results = [] + for i in range(len(self.cells)): + if (i == 0): + results.append(self.cells[i](inputs[i], inputs[i])) + else: + results.append(self.cells[i](inputs[i], results[i-1])) + + return results \ No newline at end of file diff --git a/model/train_entity.py b/model/train_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..58aa15140e6204fbc48d865036ea37eae251858a --- /dev/null +++ b/model/train_entity.py @@ -0,0 +1,29 @@ +from tqdm import tqdm +from utils.dataloader import load_entity_level_dataset +import torch + + + + +def entity_level_train(model, snapshot, optimizer, max_epoch, device, 
dataset_name, train_data): + + model = model.to(device) + model.train() + epoch_iter = tqdm(range(max_epoch)) + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in train_data: + g = load_entity_level_dataset(dataset_name, 'train', i, snapshot, device) + model.train() + loss = model(g) + loss /= len(train_data) + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward() + optimizer.step() + + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + + return model \ No newline at end of file diff --git a/model/train_graph.py b/model/train_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffbd9869fab466b92497c4815db704f9adc1727 --- /dev/null +++ b/model/train_graph.py @@ -0,0 +1,50 @@ +import numpy as np +from tqdm import tqdm +import torch +from utils.dataloader import load_graph +import matplotlib.pyplot as plt +from model.eval import batch_level_evaluation + +def batch_level_train(model, train_loader, optimizer, max_epoch, device, n_dim, e_dim, dataset_name, validation=True): + + epoch_iter = tqdm(range(max_epoch)) + model.to(device) # Move model to GPU + n_epoch = 0 + validation_f1 = [] + loss_global = [] + for epoch in epoch_iter: + model.train() + loss_list = [] + for iter, batch in enumerate(train_loader): + batch_g = [load_graph(int(idx), dataset_name, device) for idx in batch] # Move data to GPU + model.train() + g = batch_g[0] + loss = model(g) + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_list.append(loss.item()) + del batch_g, g + + n_epoch +=1 + + if (validation): + validation_f1.append(batch_level_evaluation(model, model.pooler, device, ['knn'], dataset_name, n_dim, e_dim)[0]) + + loss_global.append(np.mean(loss_list)) + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") + + if (validation): + plt.plot(list(range(n_epoch)), validation_f1, label='Graph 2', marker='o', linestyle='-') + plt.plot(list(range(n_epoch)), loss_global, label='Graph 2', marker='x', linestyle='--') + plt.xlabel('X-axis') + plt.ylabel('Y-axis') + plt.title('Two Graphs on the Same Plot') + + # Add a legend + plt.legend() + + # Display the plot + plt.show() + return model diff --git a/result/FedAvg-2client-2round-cadets.pt b/result/FedAvg-2client-2round-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..97537e00b83117c06200ca53c0e22cef67b81fe8 Binary files /dev/null and b/result/FedAvg-2client-2round-cadets.pt differ diff --git a/result/FedAvg-2client-cadets.pt b/result/FedAvg-2client-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..b9fb122af9072298b859056fa94be04932ef7a8b Binary files /dev/null and b/result/FedAvg-2client-cadets.pt differ diff --git a/result/FedAvg-4client-cadets.pt b/result/FedAvg-4client-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..4fb2cd9db49d52015b81ce437063abfa0c17d065 Binary files /dev/null and b/result/FedAvg-4client-cadets.pt differ diff --git a/result/FedAvg-SC2.pt b/result/FedAvg-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..89411388841175b247e424a74463231d45fe1391 Binary files /dev/null and b/result/FedAvg-SC2.pt differ diff --git a/result/FedAvg-Unicorn-Cadets.pt b/result/FedAvg-Unicorn-Cadets.pt new file mode 100644 index 
0000000000000000000000000000000000000000..161ba7a56e45a3d1c33bed20b12bf8564b50a8c8 Binary files /dev/null and b/result/FedAvg-Unicorn-Cadets.pt differ diff --git a/result/FedAvg-cadets-e3.pt b/result/FedAvg-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..4afaa90ee421747fa19d39673f9b63b06eb9804f Binary files /dev/null and b/result/FedAvg-cadets-e3.pt differ diff --git a/result/FedAvg-cadets.pt b/result/FedAvg-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..011e3d240dd0a104c0f4bc54541cb31c8cab5fcb Binary files /dev/null and b/result/FedAvg-cadets.pt differ diff --git a/result/FedAvg-clearscope-e3.pt b/result/FedAvg-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..136dc3c1eef647a32d2189a8896551aa1b5df162 Binary files /dev/null and b/result/FedAvg-clearscope-e3.pt differ diff --git a/result/FedAvg-streamspot.pt b/result/FedAvg-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..82308fe0b86721ef8790eff5354ec0e74388d70e Binary files /dev/null and b/result/FedAvg-streamspot.pt differ diff --git a/result/FedAvg-theia-e3.pt b/result/FedAvg-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..5a995ad213ddd3883f7c8f000095fd6821410d2b Binary files /dev/null and b/result/FedAvg-theia-e3.pt differ diff --git a/result/FedAvg-theia.pt b/result/FedAvg-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..66b64fb4c8b42b811c28d7c37238c88d5122dbd2 Binary files /dev/null and b/result/FedAvg-theia.pt differ diff --git a/result/FedAvg-trace-e3.pt b/result/FedAvg-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..4dca8e988169b150f8d1692ad751a5bb051e5609 Binary files /dev/null and b/result/FedAvg-trace-e3.pt differ diff --git a/result/FedAvg-wget-long.pt b/result/FedAvg-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..9e7504eb16446ab46f327bcc30f13d77b051f75e Binary files /dev/null and b/result/FedAvg-wget-long.pt differ diff --git a/result/FedAvg-wget.pt b/result/FedAvg-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..51c0ecdfba6589f6d63433368fb10c5810705765 Binary files /dev/null and b/result/FedAvg-wget.pt differ diff --git a/result/FedAvg_Streamspot-streamspot.pt b/result/FedAvg_Streamspot-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead Binary files /dev/null and b/result/FedAvg_Streamspot-streamspot.pt differ diff --git a/result/FedOpt-cadets.pt b/result/FedOpt-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..291831ac85f89a3b16648d005a814a064f1c2cdf Binary files /dev/null and b/result/FedOpt-cadets.pt differ diff --git a/result/FedOpt-theia.pt b/result/FedOpt-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..383ebdb7544ff96e65dc408b857d5f61803b14e4 Binary files /dev/null and b/result/FedOpt-theia.pt differ diff --git a/result/FedOpt-trace.pt b/result/FedOpt-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..3e7faf31ca32212b8158216eaab0be77c2696ada Binary files /dev/null and b/result/FedOpt-trace.pt differ diff --git a/result/FedOpt_Streamspot.pt b/result/FedOpt_Streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead Binary files /dev/null and b/result/FedOpt_Streamspot.pt differ diff --git 
a/result/FedProx-cadets.pt b/result/FedProx-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..deaf520dbef791e5e8520370443acffe65e8825b Binary files /dev/null and b/result/FedProx-cadets.pt differ diff --git a/result/FedProx-theia.pt b/result/FedProx-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..100b9f885317fcc267d0fb44a53ca3928c908a4e Binary files /dev/null and b/result/FedProx-theia.pt differ diff --git a/result/FedProx-trace.pt b/result/FedProx-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..b34e70d94c6e557d0c2f92efd1d2082076393c38 Binary files /dev/null and b/result/FedProx-trace.pt differ diff --git a/result/FedProx_Streamspot.pt b/result/FedProx_Streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..232f436cc87c609d59174ee35b8bebc1fcda303b Binary files /dev/null and b/result/FedProx_Streamspot.pt differ diff --git a/result/checkpoint-trace.pt b/result/checkpoint-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..880702e81259a697a9a794bb943927b2327b7420 Binary files /dev/null and b/result/checkpoint-trace.pt differ diff --git a/save_results/.gitkeep b/save_results/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/save_results/distance_save_trace_FedProx.pkl b/save_results/distance_save_trace_FedProx.pkl deleted file mode 100644 index e88de86f4bcb805e0475a41d4de3bbfde5d7353e..0000000000000000000000000000000000000000 Binary files a/save_results/distance_save_trace_FedProx.pkl and /dev/null differ diff --git a/test.py b/test.py deleted file mode 100644 index c1f6ebe728019b94bfcd9513907c28a9d39eaa81..0000000000000000000000000000000000000000 --- a/test.py +++ /dev/null @@ -1,15 +0,0 @@ -from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split - - -( - train_data_num, - val_data_num, - test_data_num, - train_data_global, - val_data_global, - test_data_global, - data_local_num_dict, - train_data_local_dict, - val_data_local_dict, - test_data_local_dict, - ) = load_partition_data(None, 3, "streamspot") diff --git a/train.py b/train.py deleted file mode 100644 index 968181a1cfaa6de944115c62b2b187376c734502..0000000000000000000000000000000000000000 --- a/train.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import random -import torch -import warnings -from tqdm import tqdm -from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph -from model.autoencoder import build_model -from torch.utils.data.sampler import SubsetRandomSampler -from dgl.dataloading import GraphDataLoader -import dgl -from model.train import batch_level_train -from utils.utils import set_random_seed, create_optimizer -from utils.config import build_args -warnings.filterwarnings('ignore') - - -def extract_dataloaders(entries, batch_size): - random.shuffle(entries) - train_idx = torch.arange(len(entries)) - train_sampler = SubsetRandomSampler(train_idx) - train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler) - return train_loader - - -def main(main_args): - device = "cpu" - dataset_name = "trace" - if dataset_name == 'streamspot': - main_args.num_hidden = 256 - main_args.max_epoch = 5 - main_args.num_layers = 4 - elif dataset_name == 'wget': - main_args.num_hidden = 256 - main_args.max_epoch = 2 - main_args.num_layers = 4 - else: - main_args["num_hidden"] = 64 - 
main_args["max_epoch"] = 50 - main_args["num_layers"] = 3 - set_random_seed(0) - - if dataset_name == 'streamspot' or dataset_name == 'wget': - if dataset_name == 'streamspot': - batch_size = 12 - else: - batch_size = 1 - dataset = load_batch_level_dataset(dataset_name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - graphs = dataset['dataset'] - train_index = dataset['train_index'] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - model = build_model(main_args) - model = model.to(device) - optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay) - model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)), - optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim) - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - else: - metadata = load_metadata(dataset_name) - main_args["n_dim"] = metadata['node_feature_dim'] - main_args["e_dim"] = metadata['edge_feature_dim'] - model = build_model(main_args) - model = model.to(device) - model.train() - optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"]) - epoch_iter = tqdm(range(main_args["max_epoch"])) - n_train = metadata['n_train'] - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in range(n_train): - g = load_entity_level_dataset(dataset_name, 'train', i).to(device) - model.train() - loss = model(g) - loss /= n_train - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward() - optimizer.step() - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name)) - - return - - -if __name__ == '__main__': - args = build_args() - main(args) diff --git a/trainer/.gitkeep b/trainer/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/trainer/__pycache__/magic_aggregator.cpython-310.pyc b/trainer/__pycache__/magic_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..173a82a3b45a20cb705440e95a4e6d09bf98930b Binary files /dev/null and b/trainer/__pycache__/magic_aggregator.cpython-310.pyc differ diff --git a/trainer/__pycache__/magic_aggregator.cpython-311.pyc b/trainer/__pycache__/magic_aggregator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4c63a27702b466871bad2b0f18230d3620cd143 Binary files /dev/null and b/trainer/__pycache__/magic_aggregator.cpython-311.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-310.pyc b/trainer/__pycache__/magic_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa1877df83f3deb6862f39988826b001e5273f8 Binary files /dev/null and b/trainer/__pycache__/magic_trainer.cpython-310.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-311.pyc b/trainer/__pycache__/magic_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42939be3e8c7c9ffc0d3985c7b46efe4ab53d0f Binary files /dev/null and b/trainer/__pycache__/magic_trainer.cpython-311.pyc differ diff --git a/trainer/__pycache__/single_trainer.cpython-311.pyc b/trainer/__pycache__/single_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d343331aa188a304c811aafdd0d1814cc17efe9 Binary files /dev/null and 
b/trainer/__pycache__/single_trainer.cpython-311.pyc differ diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py index a9d7f70c78d64347b655494de696050e82d0fa6f..a5d09149bf53e9180f58802853aa4e6eca33c438 100644 --- a/trainer/magic_aggregator.py +++ b/trainer/magic_aggregator.py @@ -6,11 +6,15 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from utils.config import build_args from fedml.core import ServerAggregator +from model.train_graph import batch_level_train +from model.train_entity import entity_level_train from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn + + +from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling -# Trainer for MoleculeNet. The evaluation metric is ROC-AUC -from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset -from utils.loaddata import load_batch_level_dataset +from utils.config import build_args +from utils.dataloader import load_data, load_entity_level_dataset, load_metadata class MagicWgetAggregator(ServerAggregator): def __init__(self, model, args, name): @@ -62,25 +66,30 @@ class MagicWgetAggregator(ServerAggregator): logging.info("Models match perfectly! :)") def _test(self, test_data, device, args): - args = build_args() - if (self.name == 'wget' or self.name == 'streamspot'): - logging.info("----------test--------") - + main_args = build_args() + dataset_name = self.name + nsnapshot = args.snapshot + if (self.name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + logging.info("----------test--------") + set_random_seed(0) + dataset = load_data(dataset_name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat model = self.model model.eval() model.to(device) - pooler = Pooling(args["pooling"]) - dataset = load_batch_level_dataset(self.name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) + model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + model = model.to(device) + pooler = Pooling(main_args.pooling) + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, + main_args.e_dim) else: - torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name)) metadata = load_metadata(self.name) - args["n_dim"] = metadata['node_feature_dim'] - args["e_dim"] = metadata['edge_feature_dim'] + main_args.n_dim = metadata['node_feature_dim'] + main_args.e_dim = metadata['edge_feature_dim'] model = self.model.to(device) model.eval() malicious, _ = metadata['malicious'] @@ -90,17 +99,17 @@ class MagicWgetAggregator(ServerAggregator): with torch.no_grad(): x_train = [] for i in range(n_train): - g = load_entity_level_dataset(self.name, 'train', i).to(device) + g = 
load_entity_level_dataset(dataset_name, 'train', i, nsnapshot, device) x_train.append(model.embed(g).cpu().detach().numpy()) del g x_train = np.concatenate(x_train, axis=0) skip_benign = 0 x_test = [] for i in range(n_test): - g = load_entity_level_dataset(self.name, 'test', i).to(device) + g = load_entity_level_dataset(self.name, 'test', i, nsnapshot, device) # Exclude training samples from the test set if i != n_test - 1: - skip_benign += g.number_of_nodes() + skip_benign += g[0].number_of_nodes() x_test.append(model.embed(g).cpu().detach().numpy()) del g x_test = np.concatenate(x_test, axis=0) diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py index 99c9a5b49e31a631f30053db0c3cf0623569b2eb..02c428c2459ec140dee5be82bbab64fcf0532c33 100644 --- a/trainer/magic_trainer.py +++ b/trainer/magic_trainer.py @@ -10,18 +10,18 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from fedml.core import ClientTrainer -from model.autoencoder import build_model from torch.utils.data.sampler import SubsetRandomSampler from dgl.dataloading import GraphDataLoader -from model.train import batch_level_train + +from model.train_graph import batch_level_train +from model.train_entity import entity_level_train +from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn + + from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling from utils.config import build_args -from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata -from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn -from model.autoencoder import build_model -from data.data_loader import transform_data -from utils.loaddata import load_batch_level_dataset +from utils.dataloader import load_data, load_entity_level_dataset, load_metadata # Trainer for MoleculeNet. 
The evaluation metric is ROC-AUC def extract_dataloaders(entries, batch_size): @@ -47,77 +47,47 @@ class MagicTrainer(ClientTrainer): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): - test_data = None - args = build_args() - if (self.name == "wget"): - args["num_hidden"] = 256 - args["max_epoch"] = 2 - args["num_layers"] = 4 - batch_size = 1 - elif (self.name == "streamspot"): - args["num_hidden"] = 256 - args["max_epoch"] = 5 - args["num_layers"] = 4 - batch_size = 12 - else: - args["num_hidden"] = 64 - args["max_epoch"] = 50 - args["num_layers"] = 3 - - max_test_score = 0 - best_model_params = {} - if (self.name == 'wget' or self.name == 'streamspot'): + main_args = build_args() + dataset_name = self.name + input('start') + if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + if (dataset_name == 'wget' or dataset_name == 'streamspot' or dataset_name == 'Unicorn-Cadets' or dataset_name == 'clearscope-e3'): + batch_size = 1 + main_args.max_epoch = 6 + + elif (dataset_name == 'SC2' or dataset_name == 'wget-long'): + batch_size = 1 + main_args.max_epoch = 1 - dataset = load_batch_level_dataset(self.name) - data = transform_data(train_data) + + dataset = load_data(dataset_name) n_node_feat = dataset['n_feat'] n_edge_feat = dataset['e_feat'] - graphs = data['dataset'] - train_index = data['train_index'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - #self.model = build_model(args) - self.model = self.model.to(device) - optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) - self.model = batch_level_train(self.model, graphs, (extract_dataloaders(train_index, batch_size)), - optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat) - test_score, _ = self.test(test_data, device, args) + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + validation_index = dataset['validation_index'] + label = dataset["labels"] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + main_args.optimizer = "adamw" + set_random_seed(0) + #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) + train_loader = extract_dataloaders(train_data[0], batch_size) + self.model = batch_level_train(self.model, train_loader, optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim, dataset_name, validation= False) else: - - metadata = load_metadata(self.name) - args["n_dim"] = metadata['node_feature_dim'] - args["e_dim"] = metadata['edge_feature_dim'] - self.model = self.model.to(device) - self.model.train() - optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) - epoch_iter = tqdm(range(args["max_epoch"])) - n_train = len(train_data[0]) - input("start?") - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in range(n_train): - g = train_data[0][i] - self.model.train() - loss = self.model(g) - loss /= n_train - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward(retain_graph=True) - 
optimizer.step() - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - if (self.name == 'wget' or self.name == 'streamspot'): - test_score, _ = self.test(test_data, device, args) - if test_score > self.max: - self.max = test_score - best_model_params = { - k: v.cpu() for k, v in self.model.state_dict().items() - } - else: - self.max = 0 - best_model_params = { + main_args.max_epoch = 50 + nsnapshot = args.snapshot + main_args.optimizer = "adam" + set_random_seed(0) + #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) + #train_loader = extract_dataloaders(train_index, batch_size) + self.model = entity_level_train(self.model, nsnapshot, optimizer, main_args.max_epoch, device, dataset_name, train_data[0]) + + self.max = 0 + best_model_params = { k: v.cpu() for k, v in self.model.state_dict().items() - } + } @@ -126,17 +96,31 @@ class MagicTrainer(ClientTrainer): def test(self, test_data, device, args): if (self.name == 'wget' or self.name == 'streamspot'): logging.info("----------test--------") - args = build_args() + main_args = build_args() + dataset_name = self.name + if dataset_name in ['streamspot', 'wget', 'SC2']: + main_args.num_hidden = 256 + main_args.num_layers = 4 + main_args.max_epoch = 50 + else: + main_args.num_hidden = 64 + main_args.num_layers = 3 + set_random_seed(0) + dataset = load_data(dataset_name, 1, 0.6, 0.2) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat model = self.model model.eval() model.to(device) - pooler = Pooling(args["pooling"]) - dataset = load_batch_level_dataset(self.name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) + model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + model = model.to(device) + pooler = Pooling(main_args.pooling) + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, + main_args.e_dim) + else: metadata = load_metadata(self.name) args["n_dim"] = metadata['node_feature_dim'] diff --git a/trainer/utils/.gitkeep b/trainer/utils/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/trainer/utils/config.py b/trainer/utils/config.py deleted file mode 100644 index bb9c1ef28d385b80c7673027b4f69b31c36b1175..0000000000000000000000000000000000000000 --- a/trainer/utils/config.py +++ /dev/null @@ -1,20 +0,0 @@ -import argparse - - -def build_args(): - parser = argparse.ArgumentParser(description="MAGIC") - parser.add_argument("--dataset", type=str, default="wget") - parser.add_argument("--device", type=int, default=-1) - parser.add_argument("--lr", type=float, default=0.001, - 
help="learning rate") - parser.add_argument("--weight_decay", type=float, default=5e-4, - help="weight decay") - parser.add_argument("--negative_slope", type=float, default=0.2, - help="the negative slope of leaky relu for GAT") - parser.add_argument("--mask_rate", type=float, default=0.5) - parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") - parser.add_argument("--optimizer", type=str, default="adam") - parser.add_argument("--loss_fn", type=str, default='sce') - parser.add_argument("--pooling", type=str, default="mean") - args = parser.parse_args() - return args diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py deleted file mode 100644 index 41e7dfc03ee39adcc1f5729d59aa21124d981fff..0000000000000000000000000000000000000000 --- a/trainer/utils/loaddata.py +++ /dev/null @@ -1,197 +0,0 @@ -import pickle as pkl -import time -import torch.nn.functional as F -import dgl -import networkx as nx -import json -from tqdm import tqdm -import os - - -class StreamspotDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(StreamspotDataset, self).__init__(name=name) - if name == 'streamspot': - path = './data/streamspot' - num_graphs = 600 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 300 <= idx <= 399: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -class WgetDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(WgetDataset, self).__init__(name=name) - if name == 'wget': - path = './data/wget/final' - num_graphs = 150 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 0 <= idx <= 24: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -def load_rawdata(name): - if name == 'streamspot': - path = './data/streamspot' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = StreamspotDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - elif name == 'wget': - path = './data/wget' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = WgetDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - else: - raise NotImplementedError - return raw_data - - -def load_batch_level_dataset(dataset_name): - dataset = load_rawdata(dataset_name) - graph, _ = dataset[0] - node_feature_dim = 0 - for g, _ in dataset: - node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) - edge_feature_dim = 0 - for g, _ in dataset: - edge_feature_dim = max(edge_feature_dim, 
g.edata["type"].max().item()) - node_feature_dim += 1 - edge_feature_dim += 1 - full_dataset = [i for i in range(len(dataset))] - train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] - print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) - - return {'dataset': dataset, - 'train_index': train_dataset, - 'full_index': full_dataset, - 'n_feat': node_feature_dim, - 'e_feat': edge_feature_dim} - - -def transform_graph(g, node_feature_dim, edge_feature_dim): - new_g = g.clone() - new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() - new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() - return new_g - - -def preload_entity_level_dataset(path): - path = './data/' + path - if os.path.exists(path + '/metadata.json'): - pass - else: - print('transforming') - train_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] - print('transforming') - test_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - node_feature_dim = 0 - for g in train_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - for g in test_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - node_feature_dim += 1 - edge_feature_dim = 0 - for g in train_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - for g in test_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - edge_feature_dim += 1 - result_test_gs = [] - for g in test_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_test_gs.append(g) - result_train_gs = [] - for g in train_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_train_gs.append(g) - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': len(result_train_gs), - 'n_test': len(result_test_gs) - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - for i, g in enumerate(result_train_gs): - with open(path + '/train{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - for i, g in enumerate(result_test_gs): - with open(path + '/test{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - - -def load_metadata(path): - preload_entity_level_dataset(path) - with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(path, t, n): - preload_entity_level_dataset(path) - with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: - data = pkl.load(f) - return data diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py deleted file mode 100644 index 4b9dc4aee1eac42f578952596a5e400adbd1e391..0000000000000000000000000000000000000000 --- a/trainer/utils/poolers.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch.nn as nn - - -class Pooling(nn.Module): - def __init__(self, pooler): - super(Pooling, self).__init__() - self.pooler = pooler - - def forward(self, graph, feat, t=None): - feat = feat - # Implement node type-specific pooling - with graph.local_scope(): - if t is None: - if self.pooler == 'mean': - return feat.mean(0, keepdim=True) - elif 
self.pooler == 'sum': - return feat.sum(0, keepdim=True) - elif self.pooler == 'max': - return feat.max(0, keepdim=True) - else: - raise NotImplementedError - elif isinstance(t, int): - mask = (graph.ndata['type'] == t) - if self.pooler == 'mean': - return feat[mask].mean(0, keepdim=True) - elif self.pooler == 'sum': - return feat[mask].sum(0, keepdim=True) - elif self.pooler == 'max': - return feat[mask].max(0, keepdim=True) - else: - raise NotImplementedError - else: - mask = (graph.ndata['type'] == t[0]) - for i in range(1, len(t)): - mask |= (graph.ndata['type'] == t[i]) - if self.pooler == 'mean': - return feat[mask].mean(0, keepdim=True) - elif self.pooler == 'sum': - return feat[mask].sum(0, keepdim=True) - elif self.pooler == 'max': - return feat[mask].max(0, keepdim=True) - else: - raise NotImplementedError diff --git a/trainer/utils/streamspot_parser.py b/trainer/utils/streamspot_parser.py deleted file mode 100644 index 04438eb0f6336ae2bad2015ad6caf4919f48f9a3..0000000000000000000000000000000000000000 --- a/trainer/utils/streamspot_parser.py +++ /dev/null @@ -1,53 +0,0 @@ -import networkx as nx -from tqdm import tqdm -import json -raw_path = '../data/streamspot/' - -NUM_GRAPHS = 600 -node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] -edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', - 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] -node_type_set = set(node_type_dict) -edge_type_set = set(edge_type_dict) - -count_graph = 0 -with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: - lines = f.readlines() - g = nx.DiGraph() - node_map = {} - count_node = 0 - for line in tqdm(lines): - src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') - graph_id = int(graph_id) - if src_type not in node_type_set or dst_type not in node_type_set: - continue - if etype not in edge_type_set: - continue - if graph_id != count_graph: - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) - assert graph_id == count_graph - g = nx.DiGraph() - count_node = 0 - if src not in node_map: - node_map[src] = count_node - g.add_node(count_node, type=src_type) - count_node += 1 - if dst not in node_map: - node_map[dst] = count_node - g.add_node(count_node, type=dst_type) - count_node += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=etype) - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py deleted file mode 100644 index 76a91d424a3eea0bbec87ffb635a0d98b27d72f7..0000000000000000000000000000000000000000 --- a/trainer/utils/trace_parser.py +++ /dev/null @@ -1,288 +0,0 @@ -import argparse -import json -import os -import random -import re - -from tqdm import tqdm -import networkx as nx -import pickle as pkl - - -node_type_dict = {} -edge_type_dict = {} -node_type_cnt = 0 -edge_type_cnt = 0 - -metadata = { - 'trace':{ - 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 
'ta1-trace-e3-official-1.json.3'], - 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] - }, - 'theia':{ - 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], - 'test': ['ta1-theia-e3-official-6r.json.8'] - }, - 'cadets':{ - 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], - 'test': ['ta1-cadets-e3-official-2.json'] - } -} - - -pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') -pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_type = re.compile(r'type\":\"(.*?)\"') -pattern_time = re.compile(r'timestampNanos\":(.*?),') -pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') -pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') -pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') - -adversarial = 'None' - - -def read_single_graph(dataset, malicious, path, test=False): - global node_type_cnt, edge_type_cnt - g = nx.DiGraph() - print('converting {} ...'.format(path)) - path = '../data/{}/'.format(dataset) + path + '.txt' - f = open(path, 'r') - lines = [] - for l in f.readlines(): - split_line = l.split('\t') - src, src_type, dst, dst_type, edge_type, ts = split_line - ts = int(ts) - if not test: - if src in malicious or dst in malicious: - if src in malicious and src_type != 'MemoryObject': - continue - if dst in malicious and dst_type != 'MemoryObject': - continue - - if src_type not in node_type_dict: - node_type_dict[src_type] = node_type_cnt - node_type_cnt += 1 - if dst_type not in node_type_dict: - node_type_dict[dst_type] = node_type_cnt - node_type_cnt += 1 - if edge_type not in edge_type_dict: - edge_type_dict[edge_type] = edge_type_cnt - edge_type_cnt += 1 - if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: - lines.append([dst, src, dst_type, src_type, edge_type, ts]) - else: - lines.append([src, dst, src_type, dst_type, edge_type, ts]) - lines.sort(key=lambda l: l[5]) - - node_map = {} - node_type_map = {} - node_cnt = 0 - node_list = [] - for l in lines: - src, dst, src_type, dst_type, edge_type = l[:5] - src_type_id = node_type_dict[src_type] - dst_type_id = node_type_dict[dst_type] - edge_type_id = edge_type_dict[edge_type] - if src not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if src in malicious: - src_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if src not in malicious: - i = random.randint(1, 20) - if i == 1: - src_type_id = node_type_dict['NetFlowObject'] - node_map[src] = node_cnt - g.add_node(node_cnt, type=src_type_id) - node_list.append(src) - node_type_map[src] = src_type - node_cnt += 1 - if dst not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if dst in malicious: - dst_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if dst not in malicious: - i = random.randint(1, 20) - if i == 1: - dst_type_id = node_type_dict['NetFlowObject'] - node_map[dst] = node_cnt - g.add_node(node_cnt, type=dst_type_id) - 
node_type_map[dst] = dst_type - node_list.append(dst) - node_cnt += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=edge_type_id) - if (adversarial == 'MSE' or adversarial == 'MCE') and test: - for i, node in enumerate(node_list): - if node in malicious: - while True: - another_node = random.choice(node_list) - if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) - print(node, another_node) - break - if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) - print(node, another_node) - break - if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) - print(node, another_node) - break - if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': - break - return node_map, g - - -def preprocess_dataset(dataset): - id_nodetype_map = {} - id_nodename_map = {} - for file in os.listdir('../data/{}/'.format(dataset)): - if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: - print('reading {} ...'.format(file)) - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue - if len(pattern_uuid.findall(line)) == 0: print(line) - uuid = pattern_uuid.findall(line)[0] - subject_type = pattern_type.findall(line) - - if len(subject_type) < 1: - if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: - subject_type = 'MemoryObject' - if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: - subject_type = 'NetFlowObject' - if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: - subject_type = 'UnnamedPipeObject' - else: - subject_type = subject_type[0] - - if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: - continue - id_nodetype_map[uuid] = subject_type - if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_file_name.findall(line)[0] - elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_process_name.findall(line)[0] - elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] - for key in metadata[dataset]: - for file in metadata[dataset][key]: - if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): - continue - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') - print('processing {} ...'.format(file)) - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line: - edgeType = pattern_type.findall(line)[0] - timestamp = pattern_time.findall(line)[0] - srcId = 
pattern_src.findall(line) - - if len(srcId) == 0: continue - srcId = srcId[0] - if not srcId in id_nodetype_map: - continue - srcType = id_nodetype_map[srcId] - dstId1 = pattern_dst1.findall(line) - if len(dstId1) > 0 and dstId1[0] != 'null': - dstId1 = dstId1[0] - if not dstId1 in id_nodetype_map: - continue - dstType1 = id_nodetype_map[dstId1] - this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( - dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge1) - - dstId2 = pattern_dst2.findall(line) - if len(dstId2) > 0 and dstId2[0] != 'null': - dstId2 = dstId2[0] - if not dstId2 in id_nodetype_map.keys(): - continue - dstType2 = id_nodetype_map[dstId2] - this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( - dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge2) - fw.close() - f.close() - if len(id_nodename_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') - json.dump(id_nodename_map, fw) - if len(id_nodetype_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') - json.dump(id_nodetype_map, fw) - - -def read_graphs(dataset): - malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) - f = open(malicious_entities, 'r') - malicious_entities = set() - for l in f.readlines(): - malicious_entities.add(l.lstrip().rstrip()) - - preprocess_dataset(dataset) - train_gs = [] - for file in metadata[dataset]['train']: - _, train_g = read_single_graph(dataset, malicious_entities, file, False) - train_gs.append(train_g) - test_gs = [] - test_node_map = {} - count_node = 0 - for file in metadata[dataset]['test']: - node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) - assert len(node_map) == test_g.number_of_nodes() - test_gs.append(test_g) - for key in node_map: - if key not in test_node_map: - test_node_map[key] = node_map[key] + count_node - count_node += test_g.number_of_nodes() - - if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): - with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodename_map = json.load(f) - with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodetype_map = json.load(f) - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': - final_malicious_entities.append(test_node_map[e]) - if e in id_nodename_map: - malicious_names.append(id_nodename_map[e]) - f.write('{}\t{}\n'.format(e, id_nodename_map[e])) - else: - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - else: - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map: - final_malicious_entities.append(test_node_map[e]) - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - - pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], 
open('../data/{}/test.pkl'.format(dataset), 'wb')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CDM Parser') - parser.add_argument("--dataset", type=str, default="trace") - args = parser.parse_args() - if args.dataset not in ['trace', 'theia', 'cadets']: - raise NotImplementedError - read_graphs(args.dataset) - diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py deleted file mode 100644 index bcf9a481407696d73f30fe1dde279154a05702b3..0000000000000000000000000000000000000000 --- a/trainer/utils/utils.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial -import numpy as np -import random -import torch.optim as optim - - -def create_optimizer(opt, model, lr, weight_decay): - opt_lower = opt.lower() - parameters = model.parameters() - opt_args = dict(lr=lr, weight_decay=weight_decay) - optimizer = None - opt_split = opt_lower.split("_") - opt_lower = opt_split[-1] - if opt_lower == "adam": - optimizer = optim.Adam(parameters, **opt_args) - elif opt_lower == "adamw": - optimizer = optim.AdamW(parameters, **opt_args) - elif opt_lower == "adadelta": - optimizer = optim.Adadelta(parameters, **opt_args) - elif opt_lower == "radam": - optimizer = optim.RAdam(parameters, **opt_args) - elif opt_lower == "sgd": - opt_args["momentum"] = 0.9 - return optim.SGD(parameters, **opt_args) - else: - assert False and "Invalid optimizer" - return optimizer - - -def random_shuffle(x, y): - idx = list(range(len(x))) - random.shuffle(idx) - return x[idx], y[idx] - - -def set_random_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.determinstic = True - - -def create_activation(name): - if name == "relu": - return nn.ReLU() - elif name == "gelu": - return nn.GELU() - elif name == "prelu": - return nn.PReLU() - elif name is None: - return nn.Identity() - elif name == "elu": - return nn.ELU() - else: - raise NotImplementedError(f"{name} is not implemented.") - - -def create_norm(name): - if name == "layernorm": - return nn.LayerNorm - elif name == "batchnorm": - return nn.BatchNorm1d - elif name == "graphnorm": - return partial(NormLayer, norm_type="groupnorm") - else: - return None - - -class NormLayer(nn.Module): - def __init__(self, hidden_dim, norm_type): - super().__init__() - if norm_type == "batchnorm": - self.norm = nn.BatchNorm1d(hidden_dim) - elif norm_type == "layernorm": - self.norm = nn.LayerNorm(hidden_dim) - elif norm_type == "graphnorm": - self.norm = norm_type - self.weight = nn.Parameter(torch.ones(hidden_dim)) - self.bias = nn.Parameter(torch.zeros(hidden_dim)) - - self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) - else: - raise NotImplementedError - - def forward(self, graph, x): - tensor = x - if self.norm is not None and type(self.norm) != str: - return self.norm(tensor) - elif self.norm is None: - return tensor - - batch_list = graph.batch_num_nodes - batch_size = len(batch_list) - batch_list = torch.Tensor(batch_list).long().to(tensor.device) - batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) - batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) - mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) - mean = mean.scatter_add_(0, batch_index, tensor) - mean = (mean.T / batch_list).T - mean = mean.repeat_interleave(batch_list, dim=0) - - sub = tensor - mean * self.mean_scale - - std = 
torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) - std = std.scatter_add_(0, batch_index, sub.pow(2)) - std = ((std.T / batch_list).T + 1e-6).sqrt() - std = std.repeat_interleave(batch_list, dim=0) - return self.weight * sub / std + self.bias diff --git a/trainer/utils/wget_parser.py b/trainer/utils/wget_parser.py deleted file mode 100644 index 3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055..0000000000000000000000000000000000000000 --- a/trainer/utils/wget_parser.py +++ /dev/null @@ -1,820 +0,0 @@ -import argparse -import json -import os - -import xxhash -import tqdm -import logging -import networkx as nx -from tqdm import tqdm -import time -import datetime - - -valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] -CONSOLE_ARGUMENTS = None - - -def hashgen(l): - """Generate a single hash value from a list. @l is a list of - string values, which can be properties of a node/edge. This - function returns a single hashed integer value.""" - hasher = xxhash.xxh64() - for e in l: - hasher.update(e) - return hasher.intdigest() - - -def parse_nodes(json_string, node_map): - """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). - Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, - which is assigned by CamFlow to uniquely identify a node object, to a hashed - value (in str) which represents the 'type' of the node. """ - json_object = None - try: - # use "ignore" if non-decodeable exists in the @json_string - json_object = json.loads(json_string) - except Exception as e: - print("Exception ({}) occurred when parsing a node in JSON:".format(e)) - print(json_string) - exit(1) - if "activity" in json_object: - activity = json_object["activity"] - for uid in activity: - if not uid in node_map: # only parse unseen nodes - if "prov:type" not in activity[uid]: - # a node must have a type. - # record this issue if logging is turned on - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = activity[uid]["prov:type"] - - if "entity" in json_object: - entity = json_object["entity"] - for uid in entity: - if not uid in node_map: - if "prov:type" not in entity[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = entity[uid]["prov:type"] - - -def parse_all_nodes(filename, node_map): - """Parse all nodes in CamFlow data. @filename is the file path of - the CamFlow data to parse. @node_map contains the mappings of all - CamFlow nodes to their hashed attributes. """ - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(filename, 'r') as f: - # each line in CamFlow data could contain multiple - # provenance nodes, we call @parse_nodes routine. - for line in f: - pb.update() # for progress tracking - parse_nodes(line, node_map) - f.close() - pb.close() - - -def parse_all_edges(inputfile, outputfile, node_map, noencode): - """Parse all edges (including their timestamp) from CamFlow data file @inputfile - to an @outputfile. Before this function is called, parse_all_nodes should be called - to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, - we do not hash the nodes' original UUIDs generated by CamFlow to integers. 
This - function returns the total number of valid edge parsed from CamFlow dataset. - - The output edgelist has the following format for each line, if -s is not set: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> - If -s is set, each line would look like: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" - total_edges = 0 - smallest_timestamp = None - # scan through the entire file to find the smallest timestamp from all the edges. - # this step is only needed if we need to add some statistical information. - if CONSOLE_ARGUMENTS.stats: - description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - continue - if "cf:date" not in used[uid]: - continue - if "prov:entity" not in used[uid]: - continue - if "prov:activity" not in used[uid]: - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - continue - if "cf:date" not in wasGeneratedBy[uid]: - continue - if "prov:entity" not in wasGeneratedBy[uid]: - continue - if "prov:activity" not in wasGeneratedBy[uid]: - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - continue - if "cf:date" not in wasInformedBy[uid]: - continue - if "prov:informant" not in wasInformedBy[uid]: - continue - if "prov:informed" not in wasInformedBy[uid]: - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - continue - if "cf:date" not in wasDerivedFrom[uid]: - continue - if "prov:usedEntity" not in wasDerivedFrom[uid]: - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - continue - srcUUID = 
wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - continue - if "cf:date" not in wasAssociatedWith[uid]: - continue - if "prov:agent" not in wasAssociatedWith[uid]: - continue - if "prov:activity" not in wasAssociatedWith[uid]: - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - f.close() - pb.close() - - # we will go through the CamFlow data (again) and output edgelist to a file - output = open(outputfile, "w+") - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - # an edge must have a type; if not, - # we will have to skip the edge. Log - # this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without type: {}".format(uid)) - continue - else: - edgetype = "used" - # cf:id is used as logical timestamp to order edges - if "cf:id" not in used[uid]: - # an edge must have a logical timestamp; - # if not we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = used[uid]["cf:id"] - if "prov:entity" not in used[uid]: - # an edge's source node must exist; - # if not, we will have to skip the - # edge. Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) - continue - if "prov:activity" not in used[uid]: - # an edge's destination node must exist; - # if not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], - uid)) - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - # both source and destination node must - # exist in @node_map; if not, we will - # have to skip the edge. Log this issue - # if verbose is set. 
- if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in used[uid]: - # an edge must have a timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without timestamp: {}".format(uid)) - continue - else: - # we only record @adjusted_ts if we need - # to record stats of CamFlow dataset. - if CONSOLE_ARGUMENTS.stats: - ts_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in used[uid]: - # an edge must have a jiffies timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without jiffies: {}".format(uid)) - continue - else: - # we only record @jiffies if - # the option is set - if CONSOLE_ARGUMENTS.jiffies: - jiffies = used[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasGeneratedBy" - if "cf:id" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasGeneratedBy[uid]["cf:id"] - if "prov:entity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( - 
wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasGeneratedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasInformedBy" - if "cf:id" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasInformedBy[uid]["cf:id"] - if "prov:informant" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - if "prov:informed" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) 
record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasInformedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) - continue - else: - edgetype = "wasDerivedFrom" - if "cf:id" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasDerivedFrom[uid]["cf:id"] - if "prov:usedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without jiffies: 
{}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasDerivedFrom[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) - continue - else: - edgetype = "wasAssociatedWith" - if "cf:id" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasAssociatedWith[uid]["cf:id"] - if "prov:agent" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasAssociatedWith[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - 
output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - f.close() - output.close() - pb.close() - return total_edges - -def read_single_graph(file_name, threshold): - graph = [] - edge_cnt = 0 - with open(file_name, 'r') as f: - for line in f: - try: - edge = line.strip().split("\t") - new_edge = [edge[0], edge[1]] - attributes = edge[2].strip().split(":") - source_node_type = attributes[0] - destination_node_type = attributes[1] - edge_type = attributes[2] - edge_order = attributes[3] - - new_edge.append(source_node_type) - new_edge.append(destination_node_type) - new_edge.append(edge_type) - new_edge.append(edge_order) - graph.append(new_edge) - edge_cnt += 1 - except: - print("{}".format(line)) - f.close() - graph.sort(key=lambda e: e[5]) - if len(graph) <= threshold: - return graph - else: - return graph[:threshold] - - -def process_graph(name, threshold): - graph = read_single_graph(name, threshold) - result_graph = nx.DiGraph() - cnt = 0 - for num, edge in enumerate(graph): - cnt += 1 - src, dst, src_type, dst_type, edge_type = edge[:5] - if True:# src_type in valid_node_type and dst_type in valid_node_type: - if not result_graph.has_node(src): - result_graph.add_node(src, type=src_type) - if not result_graph.has_node(dst): - result_graph.add_node(dst, type=dst_type) - if not result_graph.has_edge(src, dst): - result_graph.add_edge(src, dst, type=edge_type) - if bidirection: - result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) - return cnt, result_graph - - -node_type_list = [] -edge_type_list = [] -node_type_dict = {} -edge_type_dict = {} - - -def format_graph(g, name): - new_g = nx.DiGraph() - node_map = {} - node_cnt = 0 - for n in g.nodes: - node_map[n] = node_cnt - new_g.add_node(node_cnt, type=g.nodes[n]['type']) - node_cnt += 1 - for e in g.edges: - new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) - for n in new_g.nodes: - node_type = new_g.nodes[n]['type'] - if not node_type in node_type_dict: - node_type_list.append(node_type) - node_type_dict[node_type] = 1 - else: - node_type_dict[node_type] += 1 - for e in new_g.edges: - edge_type = new_g.edges[e]['type'] - if not edge_type in edge_type_dict: - edge_type_list.append(edge_type) - edge_type_dict[edge_type] = 1 - else: - edge_type_dict[edge_type] += 1 - for n in new_g.nodes: - new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) - for e in new_g.edges: - new_g.edges[e]['type'] = edge_type_list.index(new_g.edges[e]['type']) - with open('{}.json'.format(name), 'w', encoding='utf-8') as f: - json.dump(nx.node_link_data(new_g), f) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist') - args = parser.parse_args() - args.stats = False - args.verbose = False - args.jiffies = False - args.input = '../data/wget/raw/' - args.output = 
'../data/wget/processed/' - args.final_output = '../data/wget/final/' - args.noencode = False - if not os.path.exists(args.input): - os.mkdir(args.input) - if not os.path.exists(args.output): - os.mkdir(args.output) - if not os.path.exists(args.final_output): - os.mkdir(args.final_output) - CONSOLE_ARGUMENTS = args - - if args.verbose: - logging.basicConfig(filename=args.log, level=logging.DEBUG) - cnt = 0 - for fname in os.listdir(args.input): - cnt += 1 - node_map = dict() - parse_all_nodes(args.input + '/{}'.format(fname), node_map) - total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map, - args.noencode) - if args.stats: - total_nodes = len(node_map) - stats = open(args.stats_file + '/{}.log'.format(cnt), "a+") - stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges)) - - bidirection = False - threshold = 10000000 # infinity - interaction_dict = [] - graph_cnt = 0 - result_graphs = [] - input = args.output - base = args.final_output - - line_cnt = 0 - for i in tqdm(range(cnt)): - single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold) - format_graph(result_graph, '{}{}'.format(base, i)) - line_cnt += single_cnt - - print(line_cnt // 150) - print(len(node_type_list)) - print(node_type_dict) - print(len(edge_type_list)) - print(edge_type_dict) - diff --git a/utils/.gitkeep b/utils/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2a48894d4871479b958f92c0ae288f5c3e6cad Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ diff --git a/utils/__pycache__/config.cpython-311.pyc b/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a2ca013c4c0826c94b817a15f5aeaadf3fbf021 Binary files /dev/null and b/utils/__pycache__/config.cpython-311.pyc differ diff --git a/utils/__pycache__/configJupyter.cpython-311.pyc b/utils/__pycache__/configJupyter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470b992f5214101b8dcf54596f455e87c14e0d9b Binary files /dev/null and b/utils/__pycache__/configJupyter.cpython-311.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-310.pyc b/utils/__pycache__/dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87ef800e3349966c54f8aff0158c342c15ce957 Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-310.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-311.pyc b/utils/__pycache__/dataloader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435191a8fbab8cf94e7ec9ff76aa01f782a87157 Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-311.pyc differ diff --git a/utils/__pycache__/loaddata.cpython-311.pyc b/utils/__pycache__/loaddata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516f605481c393aef9331a111ab596d8f327e0ee Binary files /dev/null and b/utils/__pycache__/loaddata.cpython-311.pyc differ diff --git a/utils/__pycache__/poolers.cpython-310.pyc b/utils/__pycache__/poolers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441c02bbdd690fbfccc5acf786b9ded0f067bd8d Binary files /dev/null and 
b/utils/__pycache__/poolers.cpython-310.pyc differ diff --git a/utils/__pycache__/poolers.cpython-311.pyc b/utils/__pycache__/poolers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4464f0cddb9bf9523f2ee1a43193ac41ebdb2f73 Binary files /dev/null and b/utils/__pycache__/poolers.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10658e8f670d580962f3f3982605978f526561a7 Binary files /dev/null and b/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c0ff7b920416d5844cb0abf8995663dd0bf741 Binary files /dev/null and b/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/utils/config.py b/utils/config.py index ee14920f4b2e36dfcd0d7fddca100572837ade20..75b33c7e63bcbcab20114e07ab7075922c59417b 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,13 +1,23 @@ +import argparse + + def build_args(): - args = {} - args["dataset"] ="wget" - args["device"]=-1 - args["lr"]=0.001 - args["weight_decay"]=5e-4 - args["negative_slope"]=0.2 - args["mask_rate"]=0.5 - args["alpha_l"]=3 - args["optimizer"]="adam" - args["loss_fn"]='sce' - args["pooling"] = "mean" - return args + parser = argparse.ArgumentParser(description="MAGIC") + parser.add_argument("--device", type=int, default=-1) + parser.add_argument("--lr", type=float, default=0.0001, + help="learning rate") + parser.add_argument("--weight_decay", type=float, default=5e-4, + help="weight decay") + parser.add_argument("--negative_slope", type=float, default=0.2, + help="the negative slope of leaky relu for GAT") + parser.add_argument("--mask_rate", type=float, default=0.5) + parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") + parser.add_argument("--optimizer", type=str, default="adam") + parser.add_argument("--loss_fn", type=str, default='sce') + parser.add_argument("--pooling", type=str, default="mean") + parser.add_argument("--run_id", type=str, default="fedgraphnn_cs_ch") + parser.add_argument("--cf", type=str, default="fedml_config.yaml") + parser.add_argument("--rank", type=int, default=0) + parser.add_argument("--role", type=str, default="server") + args = parser.parse_args() + return args \ No newline at end of file diff --git a/utils/configJupyter.py b/utils/configJupyter.py new file mode 100644 index 0000000000000000000000000000000000000000..269779745a9c835b757ea725d201e11e5ee3e87b --- /dev/null +++ b/utils/configJupyter.py @@ -0,0 +1,14 @@ + +def build_args(): + + args = {} + args['device'] = -1 + args['lr'] = 0.01 + args['weight_decay'] = 5e-4 + args['negative_slope'] = 0.2 + args['mask_rate'] = 0.5 + args['alpha_l'] = 3 + args['optimizer'] = 'adam' + args['loss_fn'] = "sce" + args['pooling'] = "mean" + return args \ No newline at end of file diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..992c6fcacaee228330db91e920c03c6d4f2fc8ad --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,331 @@ +import os +import random +import networkx as nx +import dgl +import torch +import pickle as pkl +import json +import logging +import numpy as np + +path_dataset = 'D:/PFE DATASETS/' + +def darpa_split(name): + metadata = load_metadata(name) + n_train = metadata['n_train'] + train_dataset = range(n_train) + 
train_labels = [0]* n_train + + + return ( + train_dataset, + train_labels, + [], + [], + [], + [] + ) + + +def create_random_split(name, snapshots): + dataset = load_data(name) + # Random 80/10/10 split as suggested + + + all_idxs = list(range(len(dataset))) + random.shuffle(all_idxs) + + train_dataset = dataset['train_index'] + train_labels = [] + for id in train_dataset: + train_labels.append(dataset['labels'][id]) + + val_dataset = dataset['validation_index'] + val_labels = [] + for id in val_dataset: + val_labels.append(dataset['labels'][id]) + + test_dataset = dataset['test_index'] + test_labels = [] + for id in test_dataset: + test_labels.append(dataset['labels'][id]) + + + return ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) + + + +def partition_data_by_sample_size( + client_number, name, snapshots +): + if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = create_random_split(name, snapshots) + else: + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = darpa_split(name) + + num_train_samples = len(train_dataset) + num_val_samples = len(val_dataset) + num_test_samples = len(test_dataset) + + train_idxs = list(range(num_train_samples)) + val_idxs = list(range(num_val_samples)) + test_idxs = list(range(num_test_samples)) + + random.shuffle(train_idxs) + random.shuffle(val_idxs) + random.shuffle(test_idxs) + + partition_dicts = [None] * client_number + + + clients_idxs_train = np.array_split(train_idxs, client_number) + clients_idxs_val = np.array_split(val_idxs, client_number) + clients_idxs_test = np.array_split(test_idxs, client_number) + + labels_of_all_clients = [] + for client in range(client_number): + client_train_idxs = clients_idxs_train[client] + client_val_idxs = clients_idxs_val[client] + client_test_idxs = clients_idxs_test[client] + + train_dataset_client = [ + train_dataset[idx] for idx in client_train_idxs + ] + train_labels_client = [train_labels[idx] for idx in client_train_idxs] + labels_of_all_clients.append(train_labels_client) + + val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] + val_labels_client = [val_labels[idx] for idx in client_val_idxs] + + test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] + test_labels_client = [test_labels[idx] for idx in client_test_idxs] + + + partition_dict = { + "train": train_dataset_client, + "val": val_dataset_client, + "test": test_dataset_client, + } + + partition_dicts[client] = partition_dict + global_data_dict = { + "train": train_dataset, + "val": val_dataset, + "test": test_dataset, + } + + return global_data_dict, partition_dicts + +def load_partition_data( + client_number, + name, + snapshots, + global_test=True, +): + global_data_dict, partition_dicts = partition_data_by_sample_size( + client_number, name, snapshots + ) + + data_local_num_dict = dict() + train_data_local_dict = dict() + val_data_local_dict = dict() + test_data_local_dict = dict() + + + + # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. 
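# (Editor's annotation.) The "MOLECULE" wording above is most likely carried
# over from FedML's MoleculeNet example; in this project each sample is an
# entire provenance graph, so the same batch-size-of-1 reasoning applies.
# A hypothetical call, using the parameter names from the signature above:
#   (train_num, val_num, test_num, train_g, val_g, test_g,
#    local_num, local_train, local_val, local_test) = load_partition_data(
#        client_number=3, name="wget", snapshots=1, global_test=True)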
+ train_data_global = global_data_dict["train"] + val_data_global = global_data_dict["val"] + test_data_global = global_data_dict["test"] + train_data_num = len(global_data_dict["train"]) + val_data_num = len(global_data_dict["val"]) + test_data_num = len(global_data_dict["test"]) + + for client in range(client_number): + train_dataset_client = partition_dicts[client]["train"] + val_dataset_client = partition_dicts[client]["val"] + test_dataset_client = partition_dicts[client]["test"] + + data_local_num_dict[client] = len(train_dataset_client) + train_data_local_dict[client] = train_dataset_client, + + val_data_local_dict[client] = val_dataset_client + + test_data_local_dict[client] = ( + test_data_global + if global_test + else test_dataset_client + + ) + + logging.info( + "Client idx = {}, local sample number = {}".format( + client, len(train_dataset_client) + ) + ) + + return ( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) + + + + + + + +def preload_entity_level_dataset(name): + path = path_dataset + name + if os.path.exists(path + '/metadata.json'): + pass + else: + + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + n_train = len(os.listdir(path + '/train')) + n_test = len(os.listdir(path + '/test')) + + g = pkl.load(open(path + '/train/graph0/graph0.pkl', 'rb')) + + node_feature_dim = len(g.ndata['attr'][0]) + edge_feature_dim = len(g.edata['attr'][0]) + + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': n_train, + 'n_test': n_test + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + + + +def load_metadata(name): + preload_entity_level_dataset(name) + with open( path_dataset + name + '/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(name, t, n, snapshot, device): + preload_entity_level_dataset(name) + graphs = [] + for i in range(snapshot): + with open(path_dataset + name + '/' + t + '/graph{}/graph{}.pkl'.format(n, str(i)), 'rb') as f: + graphs.append(pkl.load(f).to(device)) + return graphs + + +def get_labels(name): + if (name=="wget" ): + return [1] * 25 + [0] * 125 + elif (name=="streamspot"): + return [0] * 300 + [1] * 100 + [0] * 200 + elif (name == 'SC2'): + return [0] * 125 + [1] * 25 + elif (name == 'Unicorn-Cadets'): + return [0] * 109 + [1] * 3 + elif (name == 'wget-long'): + return [0] * 100 + [1] * 5 + elif (name == 'clearscope-e3'): + return [0] * 44 + [1] * 50 + +def load_data(name): + if name == "wget": + n, n_dim, e_dim = 150, 14, 4 + full_dataset_index = list(range(n)) + train_dataset = list(range(50, 150)) + validation_dataset = list(range(50)) + test_dataset = list(range(50)) + elif name == "streamspot": + n, n_dim, e_dim = 600, 8, 26 + full_dataset_index = list(range(n)) + train_dataset = list(range(300)) + validation_dataset = list(range(300, 350)) + list(range(500,550)) + test_dataset = list(range(300, 400)) + list(range(400,500))+ list(range(500,600)) + elif name == 'SC2': + n_dim = len(pkl.load(open(path_dataset + 'SC2/node.pkl', 'rb')).keys()) + e_dim = len(pkl.load(open(path_dataset + 'SC2/edge.pkl', 'rb')).keys()) + n, full_dataset_index = 150, list(range(150)) + train_dataset = list(range(100)) + validation_dataset = list(range(100, 150)) + test_dataset = list(range(100, 150)) + elif 
name in ['Unicorn-Cadets', 'wget-long', 'clearscope-e3']: + n_dim = len(pkl.load(open(path_dataset + '{}/node.pkl'.format(name), 'rb')).keys()) + e_dim = len(pkl.load(open(path_dataset + '{}/edge.pkl'.format(name), 'rb')).keys()) + if name == 'Unicorn-Cadets': + n, train_dataset = 112, list(range(70)) + elif name == 'wget-long': + n, train_dataset = 105, list(range(70)) + else: + n, train_dataset = 94, list(range(30)) + full_dataset_index = list(range(n)) + validation_dataset = list(range(train_dataset[-1], n)) + test_dataset = validation_dataset + return {'dataset': full_dataset_index, + 'train_index': train_dataset, + 'test_index': test_dataset, + 'validation_index': validation_dataset, + 'full_index': full_dataset_index, + 'n_feat': n_dim, + 'e_feat': e_dim, + 'labels': get_labels(name)} + + + +def load_graph(id, name ,device): + graphs = [] + + if (name == "wget"): + path = path_dataset + 'wget/cache/' + 'graph{}'.format(str(id)) + elif (name == "streamspot"): + path = path_dataset + 'streamspot/cache/' + 'graph{}'.format(str(id)) + elif (name == "SC2"): + if (id < 125): path = path_dataset + 'SC2/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'SC2/cache/attack/' + 'graph{}'.format(str(id - 125)) + elif (name == 'Unicorn-Cadets'): + if (id < 109): path = path_dataset + 'Unicorn-Cadets/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'Unicorn-Cadets/cache/attack/' + 'graph{}'.format(str(id - 109)) + elif (name == 'wget-long'): + if (id < 100): path = path_dataset + 'wget-long/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'wget-long/cache/attack/' + 'graph{}'.format(str(id - 100)) + elif (name == 'clearscope-e3'): + if (id < 44): path = path_dataset + 'clearscope-e3/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'clearscope-e3/cache/attack/' + 'graph{}'.format(str(id - 44)) + + for fname in os.listdir(path): + graphs.append(pkl.load(open(path + '/' + fname, 'rb')).to(device)) + + return graphs \ No newline at end of file diff --git a/utils/loaddata.py b/utils/loaddata.py deleted file mode 100644 index ca48e0fd60aa69712f48c06f8083b37f679cf797..0000000000000000000000000000000000000000 --- a/utils/loaddata.py +++ /dev/null @@ -1,207 +0,0 @@ -import pickle as pkl -import time -import torch.nn.functional as F -import dgl -import networkx as nx -import json -from tqdm import tqdm -import os - - - -class StreamspotDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(StreamspotDataset, self).__init__(name=name) - if name == 'streamspot': - path = './data/streamspot' - num_graphs = 600 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 300 <= idx <= 399: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -class WgetDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(WgetDataset, self).__init__(name=name) - if name == 'wget': - pathattack = '/data/wget/finalattack' - pathbenin = 'data/wget/finalbenin' - num_graphs_benin = 125 - num_graphs_attack = 25 - self.graphs = [] - self.labels = [] - 
print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs_benin)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - self.labels.append(0) - - for i in tqdm(range(num_graphs_attack)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - self.labels.append(1) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -def load_rawdata(name): - if name == 'streamspot': - path = './data/streamspot' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = StreamspotDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - elif name == 'wget': - path = './data/wget' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = WgetDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - else: - raise NotImplementedError - return raw_data - - -def load_batch_level_dataset(dataset_name): - dataset = load_rawdata(dataset_name) - graph, _ = dataset[0] - node_feature_dim = 0 - for g, _ in dataset: - node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) - edge_feature_dim = 0 - for g, _ in dataset: - edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) - node_feature_dim += 1 - edge_feature_dim += 1 - full_dataset = [i for i in range(len(dataset))] - train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] - print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) - - return {'dataset': dataset, - 'train_index': train_dataset, - 'full_index': full_dataset, - 'n_feat': node_feature_dim, - 'e_feat': edge_feature_dim} - - -def transform_graph(g, node_feature_dim, edge_feature_dim): - new_g = g.clone() - new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() - new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() - return new_g - - -def preload_entity_level_dataset(path): - path = './data/' + path - if os.path.exists(path + '/metadata.json'): - pass - else: - print('transforming') - train_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] - print('transforming') - test_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - node_feature_dim = 0 - for g in train_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - for g in test_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - node_feature_dim += 1 - edge_feature_dim = 0 - for g in train_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - for g in test_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - edge_feature_dim += 1 - result_test_gs = [] - for g in 
test_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_test_gs.append(g) - result_train_gs = [] - for g in train_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_train_gs.append(g) - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': len(result_train_gs), - 'n_test': len(result_test_gs) - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - for i, g in enumerate(result_train_gs): - with open(path + '/train{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - for i, g in enumerate(result_test_gs): - with open(path + '/test{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - - -def load_metadata(path): - preload_entity_level_dataset(path) - with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(path, t, n): - preload_entity_level_dataset(path) - with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: - data = pkl.load(f) - return data diff --git a/utils/streamspot_parser.py b/utils/streamspot_parser.py deleted file mode 100644 index 07687828ea5ecf56f9859955e599afb584826f1f..0000000000000000000000000000000000000000 --- a/utils/streamspot_parser.py +++ /dev/null @@ -1,55 +0,0 @@ -import networkx as nx -from tqdm import tqdm -import json -raw_path = '../data/streamspot/' - -NUM_GRAPHS = 600 -node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] -edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', - 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] -node_type_set = set(node_type_dict) -edge_type_set = set(edge_type_dict) -print(raw_path) -count_graph = 0 -with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: - print("Reading") - lines = f.readlines() - print("starting") - g = nx.DiGraph() - node_map = {} - count_node = 0 - for line in tqdm(lines): - src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') - graph_id = int(graph_id) - if src_type not in node_type_set or dst_type not in node_type_set: - continue - if etype not in edge_type_set: - continue - if graph_id != count_graph: - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) - assert graph_id == count_graph - g = nx.DiGraph() - count_node = 0 - if src not in node_map: - node_map[src] = count_node - g.add_node(count_node, type=src_type) - count_node += 1 - if dst not in node_map: - node_map[dst] = count_node - g.add_node(count_node, type=dst_type) - count_node += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=etype) - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) diff --git a/utils/trace_parser.py b/utils/trace_parser.py deleted file mode 100644 index 76a91d424a3eea0bbec87ffb635a0d98b27d72f7..0000000000000000000000000000000000000000 --- a/utils/trace_parser.py +++ /dev/null @@ -1,288 +0,0 @@ -import argparse -import json -import os -import random -import re - -from tqdm 
import tqdm -import networkx as nx -import pickle as pkl - - -node_type_dict = {} -edge_type_dict = {} -node_type_cnt = 0 -edge_type_cnt = 0 - -metadata = { - 'trace':{ - 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], - 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] - }, - 'theia':{ - 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], - 'test': ['ta1-theia-e3-official-6r.json.8'] - }, - 'cadets':{ - 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], - 'test': ['ta1-cadets-e3-official-2.json'] - } -} - - -pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') -pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_type = re.compile(r'type\":\"(.*?)\"') -pattern_time = re.compile(r'timestampNanos\":(.*?),') -pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') -pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') -pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') - -adversarial = 'None' - - -def read_single_graph(dataset, malicious, path, test=False): - global node_type_cnt, edge_type_cnt - g = nx.DiGraph() - print('converting {} ...'.format(path)) - path = '../data/{}/'.format(dataset) + path + '.txt' - f = open(path, 'r') - lines = [] - for l in f.readlines(): - split_line = l.split('\t') - src, src_type, dst, dst_type, edge_type, ts = split_line - ts = int(ts) - if not test: - if src in malicious or dst in malicious: - if src in malicious and src_type != 'MemoryObject': - continue - if dst in malicious and dst_type != 'MemoryObject': - continue - - if src_type not in node_type_dict: - node_type_dict[src_type] = node_type_cnt - node_type_cnt += 1 - if dst_type not in node_type_dict: - node_type_dict[dst_type] = node_type_cnt - node_type_cnt += 1 - if edge_type not in edge_type_dict: - edge_type_dict[edge_type] = edge_type_cnt - edge_type_cnt += 1 - if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: - lines.append([dst, src, dst_type, src_type, edge_type, ts]) - else: - lines.append([src, dst, src_type, dst_type, edge_type, ts]) - lines.sort(key=lambda l: l[5]) - - node_map = {} - node_type_map = {} - node_cnt = 0 - node_list = [] - for l in lines: - src, dst, src_type, dst_type, edge_type = l[:5] - src_type_id = node_type_dict[src_type] - dst_type_id = node_type_dict[dst_type] - edge_type_id = edge_type_dict[edge_type] - if src not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if src in malicious: - src_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if src not in malicious: - i = random.randint(1, 20) - if i == 1: - src_type_id = node_type_dict['NetFlowObject'] - node_map[src] = node_cnt - g.add_node(node_cnt, type=src_type_id) - node_list.append(src) - node_type_map[src] = src_type - node_cnt += 1 - if dst not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if dst in malicious: - 
dst_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if dst not in malicious: - i = random.randint(1, 20) - if i == 1: - dst_type_id = node_type_dict['NetFlowObject'] - node_map[dst] = node_cnt - g.add_node(node_cnt, type=dst_type_id) - node_type_map[dst] = dst_type - node_list.append(dst) - node_cnt += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=edge_type_id) - if (adversarial == 'MSE' or adversarial == 'MCE') and test: - for i, node in enumerate(node_list): - if node in malicious: - while True: - another_node = random.choice(node_list) - if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) - print(node, another_node) - break - if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) - print(node, another_node) - break - if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) - print(node, another_node) - break - if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': - break - return node_map, g - - -def preprocess_dataset(dataset): - id_nodetype_map = {} - id_nodename_map = {} - for file in os.listdir('../data/{}/'.format(dataset)): - if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: - print('reading {} ...'.format(file)) - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue - if len(pattern_uuid.findall(line)) == 0: print(line) - uuid = pattern_uuid.findall(line)[0] - subject_type = pattern_type.findall(line) - - if len(subject_type) < 1: - if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: - subject_type = 'MemoryObject' - if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: - subject_type = 'NetFlowObject' - if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: - subject_type = 'UnnamedPipeObject' - else: - subject_type = subject_type[0] - - if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: - continue - id_nodetype_map[uuid] = subject_type - if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_file_name.findall(line)[0] - elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_process_name.findall(line)[0] - elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] - for key in metadata[dataset]: - for file in metadata[dataset][key]: - if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): - continue - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - fw = 
open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') - print('processing {} ...'.format(file)) - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line: - edgeType = pattern_type.findall(line)[0] - timestamp = pattern_time.findall(line)[0] - srcId = pattern_src.findall(line) - - if len(srcId) == 0: continue - srcId = srcId[0] - if not srcId in id_nodetype_map: - continue - srcType = id_nodetype_map[srcId] - dstId1 = pattern_dst1.findall(line) - if len(dstId1) > 0 and dstId1[0] != 'null': - dstId1 = dstId1[0] - if not dstId1 in id_nodetype_map: - continue - dstType1 = id_nodetype_map[dstId1] - this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( - dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge1) - - dstId2 = pattern_dst2.findall(line) - if len(dstId2) > 0 and dstId2[0] != 'null': - dstId2 = dstId2[0] - if not dstId2 in id_nodetype_map.keys(): - continue - dstType2 = id_nodetype_map[dstId2] - this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( - dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge2) - fw.close() - f.close() - if len(id_nodename_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') - json.dump(id_nodename_map, fw) - if len(id_nodetype_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') - json.dump(id_nodetype_map, fw) - - -def read_graphs(dataset): - malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) - f = open(malicious_entities, 'r') - malicious_entities = set() - for l in f.readlines(): - malicious_entities.add(l.lstrip().rstrip()) - - preprocess_dataset(dataset) - train_gs = [] - for file in metadata[dataset]['train']: - _, train_g = read_single_graph(dataset, malicious_entities, file, False) - train_gs.append(train_g) - test_gs = [] - test_node_map = {} - count_node = 0 - for file in metadata[dataset]['test']: - node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) - assert len(node_map) == test_g.number_of_nodes() - test_gs.append(test_g) - for key in node_map: - if key not in test_node_map: - test_node_map[key] = node_map[key] + count_node - count_node += test_g.number_of_nodes() - - if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): - with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodename_map = json.load(f) - with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodetype_map = json.load(f) - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': - final_malicious_entities.append(test_node_map[e]) - if e in id_nodename_map: - malicious_names.append(id_nodename_map[e]) - f.write('{}\t{}\n'.format(e, id_nodename_map[e])) - else: - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - else: - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map: - final_malicious_entities.append(test_node_map[e]) - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - - 
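# (Editor's annotation, not part of the original file.) The three pickles
# written below are what the entity-level loader later reads back:
# malicious.pkl holds the pair (malicious node indices, malicious names),
# while train.pkl and test.pkl hold the training and test graphs serialized
# with nx.node_link_data().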
pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CDM Parser') - parser.add_argument("--dataset", type=str, default="trace") - args = parser.parse_args() - if args.dataset not in ['trace', 'theia', 'cadets']: - raise NotImplementedError - read_graphs(args.dataset) - diff --git a/utils/wget_parser.py b/utils/wget_parser.py deleted file mode 100644 index 3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055..0000000000000000000000000000000000000000 --- a/utils/wget_parser.py +++ /dev/null @@ -1,820 +0,0 @@ -import argparse -import json -import os - -import xxhash -import tqdm -import logging -import networkx as nx -from tqdm import tqdm -import time -import datetime - - -valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] -CONSOLE_ARGUMENTS = None - - -def hashgen(l): - """Generate a single hash value from a list. @l is a list of - string values, which can be properties of a node/edge. This - function returns a single hashed integer value.""" - hasher = xxhash.xxh64() - for e in l: - hasher.update(e) - return hasher.intdigest() - - -def parse_nodes(json_string, node_map): - """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). - Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, - which is assigned by CamFlow to uniquely identify a node object, to a hashed - value (in str) which represents the 'type' of the node. """ - json_object = None - try: - # use "ignore" if non-decodeable exists in the @json_string - json_object = json.loads(json_string) - except Exception as e: - print("Exception ({}) occurred when parsing a node in JSON:".format(e)) - print(json_string) - exit(1) - if "activity" in json_object: - activity = json_object["activity"] - for uid in activity: - if not uid in node_map: # only parse unseen nodes - if "prov:type" not in activity[uid]: - # a node must have a type. - # record this issue if logging is turned on - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = activity[uid]["prov:type"] - - if "entity" in json_object: - entity = json_object["entity"] - for uid in entity: - if not uid in node_map: - if "prov:type" not in entity[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = entity[uid]["prov:type"] - - -def parse_all_nodes(filename, node_map): - """Parse all nodes in CamFlow data. @filename is the file path of - the CamFlow data to parse. @node_map contains the mappings of all - CamFlow nodes to their hashed attributes. """ - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(filename, 'r') as f: - # each line in CamFlow data could contain multiple - # provenance nodes, we call @parse_nodes routine. 
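# (Editor's annotation, not part of the original file.) A minimal,
# hypothetical CamFlow record of the shape parse_nodes() expects; real
# records carry many more fields:
#   {"activity": {"<uuid-1>": {"prov:type": "task"}},
#    "entity":   {"<uuid-2>": {"prov:type": "file"}}}
# Only the top-level "activity"/"entity" keys and each node's "prov:type"
# are consulted when populating node_map.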
- for line in f: - pb.update() # for progress tracking - parse_nodes(line, node_map) - f.close() - pb.close() - - -def parse_all_edges(inputfile, outputfile, node_map, noencode): - """Parse all edges (including their timestamp) from CamFlow data file @inputfile - to an @outputfile. Before this function is called, parse_all_nodes should be called - to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, - we do not hash the nodes' original UUIDs generated by CamFlow to integers. This - function returns the total number of valid edge parsed from CamFlow dataset. - - The output edgelist has the following format for each line, if -s is not set: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> - If -s is set, each line would look like: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" - total_edges = 0 - smallest_timestamp = None - # scan through the entire file to find the smallest timestamp from all the edges. - # this step is only needed if we need to add some statistical information. - if CONSOLE_ARGUMENTS.stats: - description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - continue - if "cf:date" not in used[uid]: - continue - if "prov:entity" not in used[uid]: - continue - if "prov:activity" not in used[uid]: - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - continue - if "cf:date" not in wasGeneratedBy[uid]: - continue - if "prov:entity" not in wasGeneratedBy[uid]: - continue - if "prov:activity" not in wasGeneratedBy[uid]: - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - continue - if "cf:date" not in wasInformedBy[uid]: - continue - if "prov:informant" not in wasInformedBy[uid]: - continue - if "prov:informed" not in wasInformedBy[uid]: - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, 
"%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - continue - if "cf:date" not in wasDerivedFrom[uid]: - continue - if "prov:usedEntity" not in wasDerivedFrom[uid]: - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - continue - if "cf:date" not in wasAssociatedWith[uid]: - continue - if "prov:agent" not in wasAssociatedWith[uid]: - continue - if "prov:activity" not in wasAssociatedWith[uid]: - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - f.close() - pb.close() - - # we will go through the CamFlow data (again) and output edgelist to a file - output = open(outputfile, "w+") - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - # an edge must have a type; if not, - # we will have to skip the edge. Log - # this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without type: {}".format(uid)) - continue - else: - edgetype = "used" - # cf:id is used as logical timestamp to order edges - if "cf:id" not in used[uid]: - # an edge must have a logical timestamp; - # if not we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = used[uid]["cf:id"] - if "prov:entity" not in used[uid]: - # an edge's source node must exist; - # if not, we will have to skip the - # edge. Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) - continue - if "prov:activity" not in used[uid]: - # an edge's destination node must exist; - # if not, we will have to skip the edge. - # Log this issue if verbose is set. 
- if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], - uid)) - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - # both source and destination node must - # exist in @node_map; if not, we will - # have to skip the edge. Log this issue - # if verbose is set. - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in used[uid]: - # an edge must have a timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without timestamp: {}".format(uid)) - continue - else: - # we only record @adjusted_ts if we need - # to record stats of CamFlow dataset. - if CONSOLE_ARGUMENTS.stats: - ts_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in used[uid]: - # an edge must have a jiffies timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without jiffies: {}".format(uid)) - continue - else: - # we only record @jiffies if - # the option is set - if CONSOLE_ARGUMENTS.jiffies: - jiffies = used[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasGeneratedBy" - if "cf:id" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasGeneratedBy[uid]["cf:id"] - if "prov:entity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - 
logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasGeneratedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasInformedBy" - if "cf:id" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasInformedBy[uid]["cf:id"] - if "prov:informant" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - if "prov:informed" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - 
else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasInformedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) - continue - else: - edgetype = "wasDerivedFrom" - if "cf:id" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasDerivedFrom[uid]["cf:id"] - if "prov:usedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without 
timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasDerivedFrom[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) - continue - else: - edgetype = "wasAssociatedWith" - if "cf:id" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasAssociatedWith[uid]["cf:id"] - if "prov:agent" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - 
logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasAssociatedWith[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - f.close() - output.close() - pb.close() - return total_edges - -def read_single_graph(file_name, threshold): - graph = [] - edge_cnt = 0 - with open(file_name, 'r') as f: - for line in f: - try: - edge = line.strip().split("\t") - new_edge = [edge[0], edge[1]] - attributes = edge[2].strip().split(":") - source_node_type = attributes[0] - destination_node_type = attributes[1] - edge_type = attributes[2] - edge_order = attributes[3] - - new_edge.append(source_node_type) - new_edge.append(destination_node_type) - new_edge.append(edge_type) - new_edge.append(edge_order) - graph.append(new_edge) - edge_cnt += 1 - except: - print("{}".format(line)) - f.close() - graph.sort(key=lambda e: e[5]) - if len(graph) <= threshold: - return graph - else: - return graph[:threshold] - - -def process_graph(name, threshold): - graph = read_single_graph(name, threshold) - result_graph = nx.DiGraph() - cnt = 0 - for num, edge in enumerate(graph): - cnt += 1 - src, dst, src_type, dst_type, edge_type = edge[:5] - if True:# src_type in valid_node_type and dst_type in valid_node_type: - if not result_graph.has_node(src): - result_graph.add_node(src, type=src_type) - if not result_graph.has_node(dst): - result_graph.add_node(dst, type=dst_type) - if not result_graph.has_edge(src, dst): - result_graph.add_edge(src, dst, type=edge_type) - if bidirection: - result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) - return cnt, result_graph - - -node_type_list = [] -edge_type_list = [] -node_type_dict = {} -edge_type_dict = {} - - -def format_graph(g, name): - new_g = nx.DiGraph() - node_map = {} - node_cnt = 0 - for n in g.nodes: - node_map[n] = node_cnt - new_g.add_node(node_cnt, type=g.nodes[n]['type']) - node_cnt += 1 - for e in g.edges: - new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) - for n in new_g.nodes: - node_type = new_g.nodes[n]['type'] - if not node_type in node_type_dict: - node_type_list.append(node_type) - node_type_dict[node_type] = 1 - else: - node_type_dict[node_type] += 1 - for e in new_g.edges: - edge_type = new_g.edges[e]['type'] - if not edge_type in edge_type_dict: - edge_type_list.append(edge_type) - edge_type_dict[edge_type] = 1 - else: - edge_type_dict[edge_type] += 1 - for n in new_g.nodes: - new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) - for e in new_g.edges: - new_g.edges[e]['type'] = 
edge_type_list.index(new_g.edges[e]['type'])
-    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
-        json.dump(nx.node_link_data(new_g), f)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
-    args = parser.parse_args()
-    args.stats = False
-    args.verbose = False
-    args.jiffies = False
-    args.input = '../data/wget/raw/'
-    args.output = '../data/wget/processed/'
-    args.final_output = '../data/wget/final/'
-    args.noencode = False
-    if not os.path.exists(args.input):
-        os.mkdir(args.input)
-    if not os.path.exists(args.output):
-        os.mkdir(args.output)
-    if not os.path.exists(args.final_output):
-        os.mkdir(args.final_output)
-    CONSOLE_ARGUMENTS = args
-
-    if args.verbose:
-        logging.basicConfig(filename=args.log, level=logging.DEBUG)
-    cnt = 0
-    for fname in os.listdir(args.input):
-        cnt += 1
-        node_map = dict()
-        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
-        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
-                                      args.noencode)
-        if args.stats:
-            total_nodes = len(node_map)
-            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
-            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
-
-    bidirection = False
-    threshold = 10000000 # infinity
-    interaction_dict = []
-    graph_cnt = 0
-    result_graphs = []
-    input = args.output
-    base = args.final_output
-
-    line_cnt = 0
-    for i in tqdm(range(cnt)):
-        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
-        format_graph(result_graph, '{}{}'.format(base, i))
-        line_cnt += single_cnt
-
-    print(line_cnt // 150)
-    print(len(node_type_list))
-    print(node_type_dict)
-    print(len(edge_type_list))
-    print(edge_type_dict)
-
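
For reference, each line that `parse_all_edges` writes follows the layout described in its docstring: `<source_node_id>\t<destination_node_id>\t<src_type>:<dst_type>:<edge_type>:<logical_timestamp>`, with a fifth colon-separated field appended when the stats or jiffies option is enabled. A minimal reader sketch for that layout (the helper and its field names are illustrative, not part of the script; `read_single_graph` above performs the same split):

```
# Sketch only: parse one edgelist line emitted by parse_all_edges.
# Values are kept as strings, mirroring read_single_graph.
def parse_edgelist_line(line):
    src, dst, attrs = line.rstrip("\n").split("\t")
    fields = attrs.split(":")
    edge = {
        "src": src,              # node id (hashed UUID unless noencode is set)
        "dst": dst,
        "src_type": fields[0],   # hashed source node type
        "dst_type": fields[1],   # hashed destination node type
        "edge_type": fields[2],  # relation name, e.g. "used"
        "logical_ts": fields[3], # cf:id, used to order edges
    }
    if len(fields) > 4:          # adjusted timestamp (stats) or jiffies
        edge["extra"] = fields[4]
    return edge

# parse_edgelist_line("12\t34\t5:7:used:42\n")
# -> {'src': '12', 'dst': '34', 'src_type': '5', 'dst_type': '7',
#     'edge_type': 'used', 'logical_ts': '42'}
```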
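The five relation blocks in `parse_all_edges` (`used`, `wasGeneratedBy`, `wasInformedBy`, `wasDerivedFrom`, `wasAssociatedWith`) differ only in which JSON keys hold the source and destination UUIDs. Below is a hedged refactoring sketch, not part of the original script, that captures those pairs as data; a consolidated parser could loop over this table instead of repeating the block per relation (the real code additionally checks `cf:id` and, when enabled, `cf:jiffies`):

```
# Hypothetical helper: relation name -> (source UUID key, destination UUID key),
# mirroring the hard-coded pairs in parse_all_edges.
RELATIONS = {
    "used":              ("prov:entity",     "prov:activity"),
    "wasGeneratedBy":    ("prov:activity",   "prov:entity"),
    "wasInformedBy":     ("prov:informant",  "prov:informed"),
    "wasDerivedFrom":    ("prov:usedEntity", "prov:generatedEntity"),
    "wasAssociatedWith": ("prov:agent",      "prov:activity"),
}

def iter_relation_edges(json_object, node_map):
    """Yield (relation, src_uuid, dst_uuid, record) for every record whose
    type, date and both endpoint keys are present and whose endpoints are
    already known in node_map."""
    for relation, (src_key, dst_key) in RELATIONS.items():
        for uid, record in json_object.get(relation, {}).items():
            if not all(k in record for k in ("prov:type", "cf:date", src_key, dst_key)):
                continue
            src_uuid, dst_uuid = record[src_key], record[dst_key]
            if src_uuid in node_map and dst_uuid in node_map:
                yield relation, src_uuid, dst_uuid, record
```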
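One caveat in `read_single_graph`: the edge order field (`cf:id`) is kept as a string, so `graph.sort(key=lambda e: e[5])` orders edges lexicographically ('10' sorts before '2'). If the truncation to `threshold` is meant to keep the chronologically earliest edges, a numeric key is the usual fix; a small sketch, assuming `cf:id` is numeric as CamFlow emits it:

```
# Hedged tweak, not in the original script: order edges numerically by their
# logical timestamp before truncating. Entries are
# [src, dst, src_type, dst_type, edge_type, edge_order].
def sort_edges_numerically(graph):
    return sorted(graph, key=lambda e: int(e[5]))
```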
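`format_graph` writes each graph as NetworkX node-link JSON, with node and edge `type` attributes re-encoded as integer indices into the global `node_type_list` and `edge_type_list`. A short loader sketch for downstream use (the path is only an example of what the `__main__` block produces):

```
import json
import networkx as nx

# Hypothetical loader for the node-link JSON files written by format_graph.
def load_final_graph(path):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # The dumped data records directed=True, so a DiGraph is reconstructed.
    return nx.node_link_graph(data)

# g = load_final_graph("../data/wget/final/0.json")
# print(g.number_of_nodes(), g.number_of_edges())
```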