diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 22b75bfacc1affcdc8d9e94144606c1fc1eadc6e..f40eb3e6a26ba58cdbed986f27f926b54c24ec37 --- a/README.md +++ b/README.md @@ -1,15 +1,16 @@ -# CONTINUUM-FEDHE-Graph +# FEDHE-Graph -Welcome to the official repository housing the FEDHE-Graph implementation for our solution Continuum! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Continuum, a comprehensive framework for federated learning research. +Welcome to the official repository housing the FEDHE-Graph implementation for Magic! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Magic, a comprehensive framework for federated learning research. - +Original project: https://github.com/FDUDSDE/MAGIC ## Environment Setup -The command are used in an environnement that consist of Windows 11 with anaconda installed +The commands are used in an environment consisting of Ubuntu 22.04 with miniconda installed First, create the conda environment for fedml with MPI support @@ -21,13 +22,13 @@ conda install -c conda-forge mpi4py openmpi pip install "fedml[MPI]" ``` -Clone the Continuum FedML project onto your current folder +Clone the MAGIC FedML project into your current folder ``` -git clone https://github.com/kamelferrahi/[MAGIC_FEDERATED_FedML](https://github.com/kamelferrahi/Continuum_FL) +git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML ``` -Install the necessary packages for Continuum to run +Install the necessary packages for Magic to run ``` conda install -c conda-forge aiohttp=3.9.1 aiosignal=1.3.1 anyio=4.2.0 attrdict=2.0.1 attrs=23.2.0 blis=0.7.11 boto3=1.34.12 botocore=1.34.12 brotli=1.1.0 catalogue=2.0.10 certifi=2023.11.17 chardet=5.2.0 charset-normalizer=3.3.2 click=8.1.7 cloudpathlib=0.16.0 confection=0.1.4 contourpy=1.2.0 cycler=0.12.1 cymem=2.0.8 dgl=1.1.3 dill=0.3.7 fastapi=0.92.0 fedml=0.8.13.post2 filelock=3.13.1 fonttools=4.47.0 frozenlist=1.4.1 fsspec=2023.12.2 gensim=4.3.2 gevent=23.9.1 geventhttpclient=2.0.9 gitdb=4.0.11 GitPython=3.1.40 GPUtil=1.4.0 graphviz=0.8.4 greenlet=3.0.3 h11=0.14.0 h5py=3.10.0 httpcore=1.0.2 httpx=0.26.0 idna=3.6 Jinja2=3.1.2 jmespath=1.0.1 joblib=1.3.2 kiwisolver=1.4.5 langcodes=3.3.0 MarkupSafe=2.1.3 matplotlib=3.8.2 mpi4py=3.1.3 mpmath=1.3.0 multidict=6.0.4 multiprocess=0.70.15 murmurhash=1.0.10 networkx=2.8.8 ntplib=0.4.0 numpy=1.26.3 nvidia-cublas-cu12=12.1.3.1 nvidia-cuda-cupti-cu12=12.1.105 nvidia-cuda-nvrtc-cu12=12.1.105 nvidia-cuda-runtime-cu12=12.1.105 nvidia-cudnn-cu12=8.9.2.26 nvidia-cufft-cu12=11.0.2.54 nvidia-curand-cu12=10.3.2.106 nvidia-cusolver-cu12=11.4.5.107 nvidia-cusparse-cu12=12.1.0.106 nvidia-nccl-cu12=2.18.1 nvidia-nvjitlink-cu12=12.3.101 nvidia-nvtx-cu12=12.1.105 onnx=1.15.0 packaging=23.2 paho-mqtt=1.6.1 pandas=2.1.4 pathtools=0.1.2 pillow=10.2.0 preshed=3.0.9 prettytable=3.9.0 promise=2.3 protobuf=3.20.3 psutil=5.9.7 py-machineid=0.4.6 pydantic=1.10.13 pyparsing=3.1.1 python-dateutil=2.8.2 python-rapidjson=1.14 pytz=2023.3.post1 PyYAML=6.0.1 redis=5.0.1 requests=2.31.0 s3transfer=0.10.0 scikit-learn=1.3.2 scipy=1.11.4 sentry-sdk=1.39.1 setproctitle=1.3.3 shortuuid=1.0.11 six=1.16.0 smart-open=6.3.0 smmap=5.0.1 sniffio=1.3.0 spacy=3.7.2 spacy-legacy=3.0.12 spacy-loggers=1.0.5 SQLAlchemy=2.0.25 srsly=2.4.8 starlette=0.25.0 sympy=1.12 thinc=8.2.2 threadpoolctl=3.2.0 torch=2.1.2
torch-cluster=1.6.3 torch-scatter=2.1.2 torch-sparse=0.6.18 torch-spline-conv=1.2.2 torch_geometric=2.4.0 torchvision=0.16.2 tqdm=4.66.1 triton=2.1.0 tritonclient=2.41.0 typer=0.9.0 typing_extensions=4.9.0 tzdata=2023.4 tzlocal=5.2 urllib3=2.0.7 uvicorn=0.25.0 wandb=0.13.2 wasabi=1.1.2 wcwidth=0.2.12 weasel=0.3.4 websocket-client=1.7.0 wget=3.2 yarl=1.9.4 zope.event=5.0 zope.interface=6.1 @@ -54,7 +55,7 @@ train_args: The algorithm tested are `FedAvg`, `FedProx` and `FedOpt` ## Datasets -The experiments utilize datasets similar to those in the original Continuum project. To change datasets, edit the `fedml_config.yaml` file: +The experiments utilize datasets similar to those in the original Magic project. To change datasets, edit the `fedml_config.yaml` file: ``` data_args: dataset: "wget" diff --git a/a.yaml b/a.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9f8ce6e7fa889de842853b12c3a723d56ab60899 --- /dev/null +++ b/a.yaml @@ -0,0 +1,83 @@ +common_args: + training_type: "cross_silo" + scenario: "horizontal" + using_mlops: false + config_version: release + name: "exp" + project: "runs/train" + exist_ok: false + random_seed: 0 + + common_args: + training_type: "simulation" + random_seed: 0 +comm_args: + backend: "MPI" + is_mobile: 0 + + +data_args: + dataset: "clintox" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "graphsage" + hidden_size: 32 + node_embedding_dim: 32 + graph_embedding_dim: 64 + readout_hidden_dim: 64 + alpha: 0.2 + num_heads: 2 + dropout: 0.3 + normalize_features: False + normalize_adjacency: False + sparse_adjacency: False + model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically + global_model_file_path: "./model_file_cache/global_model.pt" + +environment_args: + bootstrap: config/bootstrap.sh + +train_args: + federated_optimizer: "FedAvg" + client_id_list: + client_num_in_total: 3 + client_num_per_round: 3 + comm_round: 100 + epochs: 5 + batch_size: 64 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + metric: "prc-auc" + server_optimizer: sgd + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 + +validation_args: + frequency_of_the_test: 1 + +device_args: + worker_num: 3 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp + +comm_args: + backend: "MQTT_S3" + mqtt_config_path: config/mqtt_config.yaml + s3_config_path: config/s3_config.yaml + + +tracking_args: + # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ + enable_wandb: false + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + wandb_name: fedml_torch_moleculenet diff --git a/a.yml b/a.yml new file mode 100755 index 0000000000000000000000000000000000000000..adfbd8733fcee48a46e87a24cc48cbb1a0241f87 --- /dev/null +++ b/a.yml @@ -0,0 +1,68 @@ +common_args: + training_type: "simulation" + random_seed: 0 + +data_args: + dataset: "clintox" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "graphsage" + hidden_size: 32 + node_embedding_dim: 32 + graph_embedding_dim: 64 + readout_hidden_dim: 64 + alpha: 0.2 + num_heads: 2 + dropout: 0.3 + normalize_features: False + normalize_adjacency: False + sparse_adjacency: False + 
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically + global_model_file_path: "./model_file_cache/global_model.pt" + +environment_args: + bootstrap: config/bootstrap.sh + +train_args: + federated_optimizer: "FedAvg" + client_id_list: + client_num_in_total: 3 + client_num_per_round: 3 + comm_round: 100 + epochs: 5 + batch_size: 64 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + metric: "prc-auc" + server_optimizer: sgd + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 + +validation_args: + frequency_of_the_test: 1 + +device_args: + worker_num: 2 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp + +comm_args: + backend: "MPI" + is_mobile: 0 + + +tracking_args: + # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ + enable_wandb: false + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + wandb_name: fedml_torch_moleculenet diff --git a/checkpoints - Copie/checkpoint-SC2.pt b/checkpoints - Copie/checkpoint-SC2.pt deleted file mode 100644 index 963a001f792613ee2fb2f5f22cb25b599a8dedbb..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-SC2.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt deleted file mode 100644 index 20073e1503d9f57b75c278bdb89c757d3aaccb30..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-cadets-e3.pt b/checkpoints - Copie/checkpoint-cadets-e3.pt deleted file mode 100644 index ea7ee25689066513d46aea55e07f71df859ec8c7..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-cadets-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-clearscope-e3.pt b/checkpoints - Copie/checkpoint-clearscope-e3.pt deleted file mode 100644 index 5edaf9d73bdde92067b7e2f6f0129a4673feb945..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-clearscope-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-streamspot.pt b/checkpoints - Copie/checkpoint-streamspot.pt deleted file mode 100644 index ec1ed84e3dc16589af58039e87a34f21cc579cd5..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-streamspot.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-theia-e3.pt b/checkpoints - Copie/checkpoint-theia-e3.pt deleted file mode 100644 index 1513228c205f790835edb723dffe967eec960732..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-theia-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-trace-e3.pt b/checkpoints - Copie/checkpoint-trace-e3.pt deleted file mode 100644 index 755598d71f01bee477182c9fe454b4646ac7ffcd..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-trace-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-wget-long.pt b/checkpoints - Copie/checkpoint-wget-long.pt deleted file mode 100644 index f5c8f96c5be068033746e3b9271b983c438bee54..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-wget-long.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-wget.pt b/checkpoints - 
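The two example configs above (`a.yaml` / `a.yml`) and the README's Datasets section both come down to editing `data_args` / `train_args` in a FedML YAML file. A minimal sketch, not part of this patch, assuming PyYAML (already pinned in the dependency list) and that the file sits at the repository root; dataset names follow the `data/` folders in this diff and the optimizer values are the ones the README lists:

```python
# Hypothetical helper, not part of this repository: flip the dataset/optimizer
# fields of fedml_config.yaml programmatically instead of editing by hand.
import yaml

with open("fedml_config.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["data_args"]["dataset"] = "streamspot"            # e.g. "wget", "streamspot", "trace", ...
cfg["train_args"]["federated_optimizer"] = "FedProx"  # "FedAvg", "FedProx" or "FedOpt" per the README

with open("fedml_config.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```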
Copie/checkpoint-wget.pt deleted file mode 100644 index 1d5f6c0071b7c0d0e136aa6112ade76c083984fc..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-wget.pt and /dev/null differ diff --git a/checkpoints/checkpoint-SC2.pt b/checkpoints/checkpoint-SC2.pt deleted file mode 100644 index aec48a89437ad5ef4c3199b0547a3efc9365115f..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-SC2.pt and /dev/null differ diff --git a/checkpoints/checkpoint-Unicorn-Cadets.pt b/checkpoints/checkpoint-Unicorn-Cadets.pt deleted file mode 100644 index 90d5dcfbb036da54972648625073287bad1b6986..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-Unicorn-Cadets.pt and /dev/null differ diff --git a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt deleted file mode 100644 index 5f62ae9abed4280d950ebf431a30b2f883b386a2..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt and /dev/null differ diff --git a/checkpoints/checkpoint-cadets-e3.pt b/checkpoints/checkpoint-cadets-e3.pt deleted file mode 100644 index 157f6b404a8228c9da2483564ab892e0fd4f8a88..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-cadets-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-clearscope-e3.pt b/checkpoints/checkpoint-clearscope-e3.pt deleted file mode 100644 index 0ebdb975ec2e7455daed67ab002f1b39a9984c1b..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-clearscope-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-streamspot.pt b/checkpoints/checkpoint-streamspot.pt deleted file mode 100644 index 0065a7b8b52665153af952617ab45cde61fe14eb..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-streamspot.pt and /dev/null differ diff --git a/checkpoints/checkpoint-theia-e3.pt b/checkpoints/checkpoint-theia-e3.pt deleted file mode 100644 index 1f0fcc66e18b3e6d534f99e090e5387807451a64..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-theia-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-trace-e3.pt b/checkpoints/checkpoint-trace-e3.pt deleted file mode 100644 index f050065e56c009605562ca6e980652b6c2a33938..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-trace-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-wget.pt b/checkpoints/checkpoint-wget.pt deleted file mode 100644 index 8404279998dc213fba6b6f626d55c37ccb7c8ca9..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-wget.pt and /dev/null differ diff --git a/data/__pycache__/data_loader.cpython-311.pyc b/data/__pycache__/data_loader.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b17492ecf9ba2a9c984303dae0e410e17b03af83 Binary files /dev/null and b/data/__pycache__/data_loader.cpython-311.pyc differ diff --git a/data/cadets/graphs.zip b/data/cadets/graphs.zip new file mode 100755 index 0000000000000000000000000000000000000000..7b014b2ccfa0594ff4dc0e742a2d8453aafeb26d Binary files /dev/null and b/data/cadets/graphs.zip differ diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..f52db36d726a6d6bd742e77841eb3df28b2de34b --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,332 @@ +import logging +import pickle as pkl +import random + +import torch.utils.data as data + +from fedml.core 
import partition_class_samples_with_dirichlet_distribution +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os +import numpy as np +from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + pathattack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack' + pathbenin = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin' + num_graphs_benin = 125 + num_graphs_attack = 25 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs_benin)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(0) + + for i in tqdm(range(num_graphs_attack)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(1) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + +def darpa_split(name): + device = "cpu" + path = './data/' + name + '/' + metadata = load_metadata(name) + n_train = metadata['n_train'] + train_dataset = [] + train_labels = [] + for i in range(n_train): + g = load_entity_level_dataset(name, 'train', i).to(device) + train_dataset.append(g) + train_labels.append(0) + + return ( + train_dataset, + train_labels, + [], + [], + [], + [] + ) + + +def create_random_split(name): + dataset = load_rawdata(name) + # Random 80/10/10 split as suggested + train_range = (0, int(0.8 * len(dataset))) + val_range = ( + int(0.8 * len(dataset)), + int(0.8 * len(dataset)) + int(0.1 * len(dataset)), + ) + test_range = ( + int(0.8 * len(dataset)) + int(0.1 * len(dataset)), + len(dataset), + ) + + all_idxs = list(range(len(dataset))) + random.shuffle(all_idxs) + + train_dataset = [ + dataset[all_idxs[i]] for i in range(train_range[0], train_range[1]) + ] + + train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])] + + val_dataset = [ + dataset[all_idxs[i]] for i in range(val_range[0], val_range[1]) + ] + val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])] + + test_dataset = [ + dataset[all_idxs[i]] for i in range(test_range[0], test_range[1]) + ] + test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])] + + return ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) + + + +def partition_data_by_sample_size( + args, client_number, name, uniform=True, compact=True +): + if (name == 'wget' or name == 'streamspot'): + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = create_random_split(name) + else: + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = darpa_split(name) + + num_train_samples = len(train_dataset) + num_val_samples = len(val_dataset) + num_test_samples = len(test_dataset) + + train_idxs = list(range(num_train_samples)) + val_idxs = 
list(range(num_val_samples)) + test_idxs = list(range(num_test_samples)) + + random.shuffle(train_idxs) + random.shuffle(val_idxs) + random.shuffle(test_idxs) + + partition_dicts = [None] * client_number + + if uniform: + clients_idxs_train = np.array_split(train_idxs, client_number) + clients_idxs_val = np.array_split(val_idxs, client_number) + clients_idxs_test = np.array_split(test_idxs, client_number) + else: + clients_idxs_train = create_non_uniform_split( + args, train_idxs, client_number, True + ) + clients_idxs_val = create_non_uniform_split( + args, val_idxs, client_number, False + ) + clients_idxs_test = create_non_uniform_split( + args, test_idxs, client_number, False + ) + + labels_of_all_clients = [] + for client in range(client_number): + client_train_idxs = clients_idxs_train[client] + client_val_idxs = clients_idxs_val[client] + client_test_idxs = clients_idxs_test[client] + + train_dataset_client = [ + train_dataset[idx] for idx in client_train_idxs + ] + train_labels_client = [train_labels[idx] for idx in client_train_idxs] + labels_of_all_clients.append(train_labels_client) + + val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] + val_labels_client = [val_labels[idx] for idx in client_val_idxs] + + test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] + test_labels_client = [test_labels[idx] for idx in client_test_idxs] + + + partition_dict = { + "train": train_dataset_client, + "val": val_dataset_client, + "test": test_dataset_client, + } + + partition_dicts[client] = partition_dict + global_data_dict = { + "train": train_dataset, + "val": val_dataset, + "test": test_dataset, + } + + return global_data_dict, partition_dicts + +def load_partition_data( + args, + client_number, + name, + uniform=True, + global_test=True, + compact=True, + normalize_features=False, + normalize_adj=False, +): + global_data_dict, partition_dicts = partition_data_by_sample_size( + args, client_number, name, uniform, compact=compact + ) + + data_local_num_dict = dict() + train_data_local_dict = dict() + val_data_local_dict = dict() + test_data_local_dict = dict() + + + + # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. 
+ train_data_global = global_data_dict["train"] + val_data_global = global_data_dict["val"] + test_data_global = global_data_dict["test"] + train_data_num = len(global_data_dict["train"]) + val_data_num = len(global_data_dict["val"]) + test_data_num = len(global_data_dict["test"]) + + for client in range(client_number): + train_dataset_client = partition_dicts[client]["train"] + val_dataset_client = partition_dicts[client]["val"] + test_dataset_client = partition_dicts[client]["test"] + + data_local_num_dict[client] = len(train_dataset_client) + train_data_local_dict[client] = train_dataset_client, + + val_data_local_dict[client] = val_dataset_client + + test_data_local_dict[client] = ( + test_data_global + if global_test + else test_dataset_client + + ) + + logging.info( + "Client idx = {}, local sample number = {}".format( + client, len(train_dataset_client) + ) + ) + + return ( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) + + + +def load_batch_level_dataset_main(name): + dataset = get_data(name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +class GraphDataset(dgl.data.DGLDataset): + def __init__(self, graph_label_list): + super(GraphDataset, self).__init__(name="wget") + self.graph_label_list = graph_label_list + + def __len__(self): + return len(self.graph_label_list) + + def __getitem__(self, idx): + graph, label = self.graph_label_list[idx] + # Convert the graph to a DGLGraph to work with DGL + return graph, label + + +def transform_data(data): + dataset = GraphDataset(data[0]) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + diff --git a/data/theia/graphs.zip b/data/theia/graphs.zip new file mode 100755 index 0000000000000000000000000000000000000000..3a98e4af6ec79f7199a8d843ff734bdce4ac59f9 Binary files /dev/null and b/data/theia/graphs.zip differ diff --git a/checkpoints/checkpoint-wget-long.pt b/data/trace/graphs.zip old mode 100644 new mode 100755 similarity index 60% rename from checkpoints/checkpoint-wget-long.pt rename to data/trace/graphs.zip index 
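The new `data/data_loader.py` above is consumed the same way in `main.py` further down. A minimal usage sketch, assuming the `wget` graphs are already available locally as in the original MAGIC setup:

```python
# Sketch of calling the loader added above (mirrors the call in main.py).
# The first argument is only forwarded to the non-uniform split helper,
# so main.py simply passes None for the uniform case.
from data.data_loader import load_partition_data

(
    train_data_num, val_data_num, test_data_num,
    train_data_global, val_data_global, test_data_global,
    data_local_num_dict, train_data_local_dict,
    val_data_local_dict, test_data_local_dict,
) = load_partition_data(None, 4, "wget")   # 4 clients, uniform split

print("global training graphs:", train_data_num)
print("per-client sample counts:", data_local_num_dict)
```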
040897b99810c72f6bb6d2314aab468b3972de3e..98fdf390d57c6cec2ee07134ef7d4fe25accac2f Binary files a/checkpoints/checkpoint-wget-long.pt and b/data/trace/graphs.zip differ diff --git a/distance_save_cadets.pkl b/distance_save_cadets.pkl new file mode 100644 index 0000000000000000000000000000000000000000..1f2f428d6e2e6ffb8afbfd0d0669c46cab8718c6 Binary files /dev/null and b/distance_save_cadets.pkl differ diff --git a/eval_result/distance_save_cadets-e3 - Copie.pkl b/eval_result/distance_save_cadets-e3 - Copie.pkl deleted file mode 100644 index 7f737190e44abb34ebea3e6763c64b9f1ae3a89a..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets-e3 - Copie.pkl and /dev/null differ diff --git a/eval_result/distance_save_cadets-e3.pkl b/eval_result/distance_save_cadets-e3.pkl deleted file mode 100644 index 10c18f8c50bfce89a1afd0811cc24ce8e0595786..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets-e3.pkl and /dev/null differ diff --git a/eval_result/distance_save_cadets.pkl b/eval_result/distance_save_cadets.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2bb4de8e00a64a22e53b1a5187821e6e53dfffcc Binary files /dev/null and b/eval_result/distance_save_cadets.pkl differ diff --git a/eval_result/distance_save_theia-e3.pkl b/eval_result/distance_save_theia-e3.pkl deleted file mode 100644 index c5d31bab1acc6e943c746a154eec3a9370ccb644..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_theia-e3.pkl and /dev/null differ diff --git a/eval_result/distance_save_trace-e3 - Copie.pkl b/eval_result/distance_save_trace-e3 - Copie.pkl deleted file mode 100644 index 709b4f83bc0d8cd4d8be175fed1c54d366c9de16..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_trace-e3 - Copie.pkl and /dev/null differ diff --git a/eval_result/distance_save_trace-e3.pkl b/eval_result/distance_save_trace-e3.pkl deleted file mode 100644 index 709b4f83bc0d8cd4d8be175fed1c54d366c9de16..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_trace-e3.pkl and /dev/null differ diff --git a/fedml_config.yaml b/fedml_config.yaml old mode 100644 new mode 100755 index 40cd616a925f3514dda97dd8ba1b20aaead81744..66714a20c7a2b471676dac8be0801b3ae132e5cf --- a/fedml_config.yaml +++ b/fedml_config.yaml @@ -1,53 +1,49 @@ common_args: - training_type: "cross_silo" - scenario: "horizontal" - using_mlops: false - config_version: release - name: "exp" - project: "runs/train" - exist_ok: false + training_type: "simulation" random_seed: 0 data_args: - dataset: "trace-e3" + dataset: "wget" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition model_args: model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically global_model_file_path: "./model_file_cache/global_model.pt" +environment_args: + bootstrap: config/bootstrap.sh train_args: federated_optimizer: "FedAvg" client_id_list: - client_num_in_total: 2 - client_num_per_round: 2 - comm_round: 1 - snapshot: 1 + client_num_in_total: 4 + client_num_per_round: 4 + comm_round: 100 + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 validation_args: frequency_of_the_test: 1 device_args: - worker_num: 2 - using_gpu: true - gpu_mapping_file: gpu_mapping.yaml - gpu_mapping_key: mapping_config + worker_num: 4 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp comm_args: - backend: "MQTT_S3" - mqtt_config_path: 
config/mqtt_config.yaml + backend: "MPI" + is_mobile: 0 + tracking_args: # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ enable_wandb: false wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 wandb_project: fedml - wandb_name: fedml_torch - -# fhe_args: -# # enable_fhe: true - # scheme: ckks -# batch_size: 8192 -# scaling_factor: 52 -# file_loc: "resources/cryptoparams/" \ No newline at end of file + wandb_name: fedml_torch_moleculenet diff --git a/gpu_mapping.yaml b/gpu_mapping.yaml deleted file mode 100644 index 90b13bfe6abfa84b56845ac220a5c67ced781cbe..0000000000000000000000000000000000000000 --- a/gpu_mapping.yaml +++ /dev/null @@ -1,2 +0,0 @@ -mapping_config: - host1: [3] \ No newline at end of file diff --git a/main.py b/main.py old mode 100644 new mode 100755 index 7d7430d7fb88d4259d8096266633bb5dbd94bc8f..a1a716e17f69f73d33fa62ef5b3da1e85b90b715 --- a/main.py +++ b/main.py @@ -1,18 +1,20 @@ import logging import fedml -from utils.dataloader import load_partition_data, load_data, load_metadata, darpa_split +from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split from fedml import FedMLRunner from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from model.model import STGNN_AutoEncoder +from model.autoencoder import build_model from utils.config import build_args from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator +from trainer.single_trainer import train_single +from utils.loaddata import load_batch_level_dataset, load_metadata -def generate_dataset(name, number, nsnapshot): +def generate_dataset(name, number): ( train_data_num, val_data_num, @@ -24,7 +26,7 @@ def generate_dataset(name, number, nsnapshot): train_data_local_dict, val_data_local_dict, test_data_local_dict, - ) = load_partition_data(number, name, nsnapshot) + ) = load_partition_data(None, number, name) dataset = [ train_data_num, test_data_num, @@ -36,9 +38,9 @@ def generate_dataset(name, number, nsnapshot): len(train_data_global), ] - if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + if (name == "wget" or name == "streamspot"): - return dataset, load_data(name) + return dataset, load_batch_level_dataset(name) else: return dataset, load_metadata(name) @@ -47,47 +49,39 @@ if __name__ == "__main__": # init FedML framework args = fedml.init() # init device - device = fedml.device.get_device(args) - dataset_name = args.dataset + name = args.dataset number = args.client_num_in_total - nsnapshot = args.snapshot - dataset, metadata = generate_dataset(dataset_name, number, nsnapshot) + + dataset, metadata = generate_dataset(name, number) main_args = build_args() - if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - - main_args.max_epoch = 6 - out_dim = 64 - if (dataset_name == 'SC2'): - gnn_layer = 3 - else: - gnn_layer = 5 + if (name == "wget"): + main_args["num_hidden"] = 256 + main_args["max_epoch"] = 2 + main_args["num_layers"] = 4 n_node_feat = metadata['n_feat'] n_edge_feat = metadata['e_feat'] - use_all_hidden = True - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - else: - use_all_hidden = False - n_node_feat = metadata['node_feature_dim'] - n_edge_feat = metadata['edge_feature_dim'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 
32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - main_args.max_epoch = 50 - out_dim = 64 - - if (dataset_name == 'cadets-e3'): - gnn_layer = 4 - else: - gnn_layer = 3 - - - - model = STGNN_AutoEncoder(main_args.n_dim, main_args.e_dim, out_dim, out_dim, gnn_layer, 4, device, nsnapshot, 'prelu', 0.1, main_args.negative_slope, True, 'BatchNorm', main_args.pooling, alpha_l=main_args.alpha_l, use_all_hidden=use_all_hidden).to(device) # Move model to GPU + main_args["n_dim"] = n_node_feat + main_args["e_dim"] = n_edge_feat + elif (name == "streamspot"): + main_args["num_hidden"] = 256 + main_args["max_epoch"] = 5 + main_args["num_layers"] = 4 + n_node_feat = metadata['n_feat'] + n_edge_feat = metadata['e_feat'] + main_args["n_dim"] = n_node_feat + main_args["e_dim"] = n_edge_feat + else: + main_args["num_hidden"] = 64 + main_args["max_epoch"] = 50 + main_args["num_layers"] = 3 + main_args["n_dim"] = metadata["node_feature_dim"] + main_args["e_dim"] = metadata["edge_feature_dim"] + + model = build_model(main_args) #train_single(main_args, model, data) - trainer = MagicTrainer(model, args, dataset_name) - aggregator = MagicWgetAggregator(model, args, dataset_name) + trainer = MagicTrainer(model, args, name) + aggregator = MagicWgetAggregator(model, args, name) fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator) fedml_runner.run() # start training diff --git a/model/__pycache__/autoencoder.cpython-311.pyc b/model/__pycache__/autoencoder.cpython-311.pyc index 3abc72625245a57fa06470e3659c594a0c033046..13b41e99adb07cc4cfa61dd2771b34854aa88d80 100644 Binary files a/model/__pycache__/autoencoder.cpython-311.pyc and b/model/__pycache__/autoencoder.cpython-311.pyc differ diff --git a/model/__pycache__/eval.cpython-310.pyc b/model/__pycache__/eval.cpython-310.pyc deleted file mode 100644 index 0dc58b1214be9d335558c42907cf1feb18b81969..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/eval.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/eval.cpython-311.pyc b/model/__pycache__/eval.cpython-311.pyc old mode 100644 new mode 100755 index faade8397a5b3c546b35f34aef088d8992486510..c7d0bb5b6623e0d409dc28636e2722f423c0c7f0 Binary files a/model/__pycache__/eval.cpython-311.pyc and b/model/__pycache__/eval.cpython-311.pyc differ diff --git a/model/__pycache__/gat.cpython-310.pyc b/model/__pycache__/gat.cpython-310.pyc deleted file mode 100644 index 54fa4abdd4680faa41fbf0e19697fab2f000402b..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/gat.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/gat.cpython-311.pyc b/model/__pycache__/gat.cpython-311.pyc old mode 100644 new mode 100755 index 94e6046f4f22c6a2f271c2ca3bded5d8f2417cfa..0c9997f570d123655a70df66a86c3003e82f25ab Binary files a/model/__pycache__/gat.cpython-311.pyc and b/model/__pycache__/gat.cpython-311.pyc differ diff --git a/model/__pycache__/loss_func.cpython-310.pyc b/model/__pycache__/loss_func.cpython-310.pyc deleted file mode 100644 index 79ac15810418709637606b15c5c43b4b96aec691..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/loss_func.cpython-310.pyc and /dev/null differ diff --git 
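For context on the rewritten `main.py` above: `fedml.init()` reads the YAML shown earlier (typically passed as `--cf fedml_config.yaml`) and exposes the config sections as flat attributes, which is why `args.dataset` and `args.client_num_in_total` are read directly. With the MPI backend, runs of this kind are normally launched through `mpirun` against the repo's `mpi_host_file`, one process per worker plus one for the server. A small sketch of the attribute access, assuming that launch setup:

```python
# Sketch only: what fedml.init() hands back for the config in this patch.
import fedml

args = fedml.init()                       # parses --cf fedml_config.yaml
print(args.dataset)                       # "wget"   (from data_args)
print(args.client_num_in_total)           # 4        (from train_args)
print(args.federated_optimizer)           # "FedAvg" (from train_args)
device = fedml.device.get_device(args)    # same call main.py makes
```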
a/model/__pycache__/loss_func.cpython-311.pyc b/model/__pycache__/loss_func.cpython-311.pyc old mode 100644 new mode 100755 index 40b3c9d58012a8600cda16f7d39ee5c5891fa862..620a6bee7cb3ab8dc89bbc3448fdac41abb5f5ea Binary files a/model/__pycache__/loss_func.cpython-311.pyc and b/model/__pycache__/loss_func.cpython-311.pyc differ diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc deleted file mode 100644 index 3984a3444fb13f56ea0c2629e2b0883298bf8f4d..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/model.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/model.cpython-311.pyc b/model/__pycache__/model.cpython-311.pyc deleted file mode 100644 index d9a27f13ee6839a64257d3125b777544a866f329..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/model.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/rnn.cpython-310.pyc b/model/__pycache__/rnn.cpython-310.pyc deleted file mode 100644 index d02d9b442912cda476cacbc18673eebe9e000da7..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/rnn.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/rnn.cpython-311.pyc b/model/__pycache__/rnn.cpython-311.pyc deleted file mode 100644 index 9d71283ae5cbfc01ec9eb9f1448d8e8b292d1f08..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/rnn.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/test.cpython-310.pyc b/model/__pycache__/test.cpython-310.pyc deleted file mode 100644 index 1d85be9aa8853ef325090a1b1cd0f366d3224792..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/test.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/test.cpython-311.pyc b/model/__pycache__/test.cpython-311.pyc deleted file mode 100644 index 0d7beba377fb059135e3135b292a910477636c7c..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/test.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/train.cpython-311.pyc b/model/__pycache__/train.cpython-311.pyc old mode 100644 new mode 100755 index b788a354802ea312364f91691db5b00e502922a6..c202e2105beb34aa9db7a6eddeffd4cef4b25fa3 Binary files a/model/__pycache__/train.cpython-311.pyc and b/model/__pycache__/train.cpython-311.pyc differ diff --git a/model/__pycache__/train_entity.cpython-310.pyc b/model/__pycache__/train_entity.cpython-310.pyc deleted file mode 100644 index cbe4cebe4c9d3a09cdc4a8b78f976f98a23baef1..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_entity.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/train_entity.cpython-311.pyc b/model/__pycache__/train_entity.cpython-311.pyc deleted file mode 100644 index cf1b64f3a76e749403864eb4a4a10154a8003a84..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_entity.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/train_graph.cpython-310.pyc b/model/__pycache__/train_graph.cpython-310.pyc deleted file mode 100644 index cf931f350c1de50fb725f95c7ebc73f1c6facc66..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_graph.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/train_graph.cpython-311.pyc b/model/__pycache__/train_graph.cpython-311.pyc deleted file mode 100644 index d98093beaf11c1ffeb71b9abc556e33503934038..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_graph.cpython-311.pyc and 
/dev/null differ diff --git a/model/autoencoder.py b/model/autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..da51f590f02076ad43a0f8f40037ae4ad5b4c761 --- /dev/null +++ b/model/autoencoder.py @@ -0,0 +1,179 @@ +from .gat import GAT +from utils.utils import create_norm +from functools import partial +from itertools import chain +from .loss_func import sce_loss +import torch +import torch.nn as nn +import dgl +import random + + +def build_model(args): + num_hidden = args["num_hidden"] + num_layers = args["num_layers"] + negative_slope = args["negative_slope"] + mask_rate = args["mask_rate"] + alpha_l = args["alpha_l"] + n_dim = args["n_dim"] + e_dim = args["e_dim"] + + model = GMAEModel( + n_dim=n_dim, + e_dim=e_dim, + hidden_dim=num_hidden, + n_layers=num_layers, + n_heads=4, + activation="prelu", + feat_drop=0.1, + negative_slope=negative_slope, + residual=True, + mask_rate=mask_rate, + norm='BatchNorm', + loss_fn='sce', + alpha_l=alpha_l + ) + return model + + +class GMAEModel(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, n_layers, n_heads, activation, + feat_drop, negative_slope, residual, norm, mask_rate=0.5, loss_fn="sce", alpha_l=2): + super(GMAEModel, self).__init__() + self._mask_rate = mask_rate + self._output_hidden_size = hidden_dim + self.recon_loss = nn.BCELoss(reduction='mean') + + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform(m.weight) + nn.init.constant_(m.bias, 0) + + self.edge_recon_fc = nn.Sequential( + nn.Linear(hidden_dim * n_layers * 2, hidden_dim), + nn.LeakyReLU(negative_slope), + nn.Linear(hidden_dim, 1), + nn.Sigmoid() + ) + self.edge_recon_fc.apply(init_weights) + + assert hidden_dim % n_heads == 0 + enc_num_hidden = hidden_dim // n_heads + enc_nhead = n_heads + + dec_in_dim = hidden_dim + dec_num_hidden = hidden_dim + + # build encoder + self.encoder = GAT( + n_dim=n_dim, + e_dim=e_dim, + hidden_dim=enc_num_hidden, + out_dim=enc_num_hidden, + n_layers=n_layers, + n_heads=enc_nhead, + n_heads_out=enc_nhead, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=True, + ) + + # build decoder for attribute prediction + self.decoder = GAT( + n_dim=dec_in_dim, + e_dim=e_dim, + hidden_dim=dec_num_hidden, + out_dim=n_dim, + n_layers=1, + n_heads=n_heads, + n_heads_out=1, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=False, + ) + + self.enc_mask_token = nn.Parameter(torch.zeros(1, n_dim)) + self.encoder_to_decoder = nn.Linear(dec_in_dim * n_layers, dec_in_dim, bias=False) + + # * setup loss function + self.criterion = self.setup_loss_fn(loss_fn, alpha_l) + + @property + def output_hidden_dim(self): + return self._output_hidden_size + + def setup_loss_fn(self, loss_fn, alpha_l): + if loss_fn == "sce": + criterion = partial(sce_loss, alpha=alpha_l) + else: + raise NotImplementedError + return criterion + + def encoding_mask_noise(self, g, mask_rate=0.3): + new_g = g.clone() + num_nodes = g.num_nodes() + perm = torch.randperm(num_nodes, device=g.device) + + # random masking + num_mask_nodes = int(mask_rate * num_nodes) + mask_nodes = perm[: num_mask_nodes] + keep_nodes = perm[num_mask_nodes:] + + new_g.ndata["attr"][mask_nodes] = self.enc_mask_token + + return new_g, (mask_nodes, keep_nodes) + + def forward(self, g): + loss = self.compute_loss(g) + 
return loss + + def compute_loss(self, g): + # Feature Reconstruction + pre_use_g, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, self._mask_rate) + pre_use_x = pre_use_g.ndata['attr'].to(pre_use_g.device) + use_g = pre_use_g + enc_rep, all_hidden = self.encoder(use_g, pre_use_x, return_hidden=True) + enc_rep = torch.cat(all_hidden, dim=1) + rep = self.encoder_to_decoder(enc_rep) + + recon = self.decoder(pre_use_g, rep) + x_init = g.ndata['attr'][mask_nodes] + x_rec = recon[mask_nodes] + loss = self.criterion(x_rec, x_init) + + # Structural Reconstruction + threshold = min(10000, g.num_nodes()) + + negative_edge_pairs = dgl.sampling.global_uniform_negative_sampling(g, threshold) + positive_edge_pairs = random.sample(range(g.number_of_edges()), threshold) + positive_edge_pairs = (g.edges()[0][positive_edge_pairs], g.edges()[1][positive_edge_pairs]) + sample_src = enc_rep[torch.cat([positive_edge_pairs[0], negative_edge_pairs[0]])].to(g.device) + sample_dst = enc_rep[torch.cat([positive_edge_pairs[1], negative_edge_pairs[1]])].to(g.device) + y_pred = self.edge_recon_fc(torch.cat([sample_src, sample_dst], dim=-1)).squeeze(-1) + y = torch.cat([torch.ones(len(positive_edge_pairs[0])), torch.zeros(len(negative_edge_pairs[0]))]).to( + g.device) + loss += self.recon_loss(y_pred, y) + return loss + + def embed(self, g): + x = g.ndata['attr'].to(g.device) + rep = self.encoder(g, x) + return rep + + @property + def enc_params(self): + return self.encoder.parameters() + + @property + def dec_params(self): + return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) diff --git a/model/eval.py b/model/eval.py old mode 100644 new mode 100755 index 7446db31e19a2e144eb061e0526cb44e23798935..ff777aeae5f9a8664fe89c384334eb93d02c5f44 --- a/model/eval.py +++ b/model/eval.py @@ -7,35 +7,28 @@ import numpy as np from sklearn.metrics import roc_auc_score, precision_recall_curve from sklearn.neighbors import NearestNeighbors from utils.utils import set_random_seed -from utils.dataloader import load_graph, load_data -import pickle as pkl +from data.data_loader import load_batch_level_dataset_main +from utils.loaddata import transform_graph, load_batch_level_dataset -def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): - print('Start Evaluation') +def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): model.eval() x_list = [] y_list = [] - data = load_data(dataset) + data = load_batch_level_dataset(dataset) full = data['full_index'] - labels = data['labels'] + graphs = data['dataset'] with torch.no_grad(): for i in full: - #break - g = load_graph(i, dataset, device) - label = labels[i] + g = transform_graph(graphs[i][0], n_dim, e_dim).to(device) + label = graphs[i][1] out = model.embed(g) if dataset != 'wget': - out = pooler(g[-1], out).cpu().numpy() + out = pooler(g, out).cpu().numpy() else: - out = pooler(g[-1], out, [1]).cpu().numpy() + out = pooler(g, out, [2]).cpu().numpy() y_list.append(label) x_list.append(out) - - #pkl.dump(x_list,open('xlist.pkl','wb') ) - #pkl.dump(y_list,open('ylist.pkl','wb') ) - #x_list = pkl.load(open('xlist.pkl','rb')) - #y_list = pkl.load(open('ylist.pkl','rb')) x = np.concatenate(x_list, axis=0) y = np.array(y_list) if 'knn' in method: @@ -49,18 +42,9 @@ def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): x, y = embeddings, labels if dataset == 'streamspot': train_count = 400 - elif (dataset == 'Unicorn-Cadets' or dataset == 'wget-long'): - train_count = 70 - elif 
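The `build_model` / `GMAEModel` pair above is driven from `main.py` through a plain dict of hyper-parameters. A construction sketch with the values `main.py` uses for the batch-level datasets; `negative_slope`, `mask_rate` and `alpha_l` normally come from `build_args()` in `utils/config.py`, so the numbers below are placeholders, as are the feature dimensions (in practice they are the dataset's `n_feat` / `e_feat` metadata):

```python
# Sketch, not part of the patch: construct the graph masked autoencoder directly.
from model.autoencoder import build_model

main_args = {
    "num_hidden": 256,      # value main.py sets for "wget"/"streamspot"
    "num_layers": 4,
    "negative_slope": 0.2,  # placeholder; real value comes from build_args()
    "mask_rate": 0.5,       # placeholder; real value comes from build_args()
    "alpha_l": 3,           # placeholder; real value comes from build_args()
    "n_dim": 30,            # placeholder node-feature dimension (metadata['n_feat'])
    "e_dim": 10,            # placeholder edge-feature dimension (metadata['e_feat'])
}

model = build_model(main_args)
print(sum(p.numel() for p in model.parameters()), "parameters")

# Training and inference then follow the API defined above:
#   loss = model(batched_dgl_graph)    # masked-feature + edge-reconstruction loss
#   reps = model.embed(batched_dgl_graph)
```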
(dataset == 'wget' or dataset == 'SC2'): - train_count = 100 - else: - train_count = 30 - - if (dataset =='SC2'): - n_neighbors = min(int(train_count * 0.02), 10) else: - n_neighbors = 100 - + train_count = 100 + n_neighbors = min(int(train_count * 0.02), 10) benign_idx = np.where(y == 0)[0] attack_idx = np.where(y == 1)[0] if repeat != -1: @@ -134,65 +118,50 @@ def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): np.random.shuffle(benign_idx) np.random.shuffle(attack_idx) x_train = x[benign_idx[:train_count]] - #x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) x_train_mean = x_train.mean(axis=0) x_train_std = x_train.std(axis=0) + for i in range(len(x_train_std)): + if (x_train_std[i] == 0 ): + x_train_std[i] = 0.000000000000001 x_train = (x_train - x_train_mean) / x_train_std x_test = (x_test - x_train_mean) / x_train_std - f1_max = 0 - for n_neighbors in range(1, train_count): - nbrs = NearestNeighbors(n_neighbors=n_neighbors) - nbrs.fit(x_train) - distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) - mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) - #mean_distance = 0.1 - distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) - - score = distances.mean(axis=1) / mean_distance - auc = roc_auc_score(y_test, score) - prec, rec, threshold = precision_recall_curve(y_test, score) - f1 = 2 * prec * rec / (rec + prec + 1e-9) - best_idx = np.argmax(f1) - best_thres = threshold[best_idx] - - tn = 0 - fn = 0 - tp = 0 - fp = 0 - - for i in range(len(y_test)): - if y_test[i] == 1.0 and score[i] >= best_thres: - tp += 1 - if y_test[i] == 1.0 and score[i] < best_thres: - fn += 1 - if y_test[i] == 0.0 and score[i] < best_thres: - tn += 1 - if y_test[i] == 0.0 and score[i] >= best_thres: - fp += 1 - - if (f1[best_idx]> f1_max): - f1_max = f1[best_idx] - auc_max = auc - prec_max = prec[best_idx] - rec_max = rec[best_idx] - tn_max = tn - fn_max = fn - tp_max = tp - fp_max = fp - best_n = n_neighbors - - print('AUC: {}'.format(auc_max)) - print('F1: {}'.format(f1_max)) - print('PRECISION: {}'.format(prec_max)) - print('RECALL: {}'.format(rec_max)) - print('TN: {}'.format(tn_max)) - print('FN: {}'.format(fn_max)) - print('TP: {}'.format(tp_max)) - print('FP: {}'.format(fp_max)) - print(best_n) + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + print('AUC: {}'.format(auc)) + print('F1: {}'.format(f1[best_idx])) + print('PRECISION: {}'.format(prec[best_idx])) + print('RECALL: {}'.format(rec[best_idx])) + print('TN: {}'.format(tn)) + print('FN: {}'.format(fn)) + 
print('TP: {}'.format(tp)) + print('FP: {}'.format(fp)) return auc, 0.0 @@ -202,23 +171,26 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): x_train = (x_train - x_train_mean) / x_train_std x_test = (x_test - x_train_mean) / x_train_std - if dataset == 'cadets-e3': + if dataset == 'cadets': n_neighbors = 200 else: n_neighbors = 10 - nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1) nbrs.fit(x_train) save_dict_path = './eval_result/distance_save_{}.pkl'.format(dataset) if not os.path.exists(save_dict_path): idx = list(range(x_train.shape[0])) - random.shuffle(idx) + random.shuffle(x_train[idx][:min(50000, x_train.shape[0])]) + print(len(x_train[idx][:min(50000, x_train.shape[0])])) distances, _ = nbrs.kneighbors(x_train[idx][:min(50000, x_train.shape[0])], n_neighbors=n_neighbors) del x_train mean_distance = distances.mean() del distances + print('here') + print(len(x_test)) distances, _ = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + print("yes") save_dict = [mean_distance, distances.mean(axis=1)] distances = distances.mean(axis=1) with open(save_dict_path, 'wb') as f: @@ -231,8 +203,7 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): auc = roc_auc_score(y_test, score) prec, rec, threshold = precision_recall_curve(y_test, score) f1 = 2 * prec * rec / (rec + prec + 1e-9) - best_idx = np.argmax(f1) - + best_idx = -1 for i in range(len(f1)): # To repeat peak performance if dataset == 'trace' and rec[i] < 0.99979: @@ -241,7 +212,7 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): if dataset == 'theia' and rec[i] < 0.99996: best_idx = i - 1 break - if dataset == 'cadets-e3' and rec[i] < 0.9976: + if dataset == 'cadets' and rec[i] < 0.9976: best_idx = i - 1 break best_thres = threshold[best_idx] @@ -252,11 +223,11 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): fp = 0 for i in range(len(y_test)): if y_test[i] == 1.0 and score[i] >= best_thres: - tp += 1 + tn += 1 if y_test[i] == 1.0 and score[i] < best_thres: fn += 1 if y_test[i] == 0.0 and score[i] < best_thres: - tn += 1 + tp += 1 if y_test[i] == 0.0 and score[i] >= best_thres: fp += 1 print('AUC: {}'.format(auc)) @@ -267,4 +238,4 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): print('FN: {}'.format(fn)) print('TP: {}'.format(tp)) print('FP: {}'.format(fp)) - return auc, 0.0, None, None \ No newline at end of file + return auc, 0.0, None, None diff --git a/model/gat.py b/model/gat.py old mode 100644 new mode 100755 index 0b2d2daeca4468f356274049f7f4454496a8ac06..c64054f0dc985f37d6c1588afc5793eaf36417a6 --- a/model/gat.py +++ b/model/gat.py @@ -34,6 +34,7 @@ class GAT(nn.Module): last_activation = create_activation(activation) if encoding else None last_residual = (encoding and residual) last_norm = norm if encoding else None + if self.n_layers == 1: self.gats.append(GATConv( n_dim, e_dim, out_dim, n_heads_out, feat_drop, attn_drop, negative_slope, diff --git a/model/loss_func.py b/model/loss_func.py old mode 100644 new mode 100755 diff --git a/model/mlp.py b/model/mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..e1ee9581a033acdbec3b4501cdd45c2d72e7efd9 --- /dev/null +++ b/model/mlp.py @@ -0,0 +1,13 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.1): + super(MLP, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def 
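The evaluation logic above reduces to a k-NN distance score over the learned embeddings. A self-contained sketch of that scoring on synthetic data (all numbers are illustrative only), including the zero-std guard this patch adds:

```python
# Sketch: fit NearestNeighbors on standardized benign training embeddings,
# score each test point by its mean neighbor distance relative to the training
# mean, then read AUC and the best-F1 threshold off the scores.
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import roc_auc_score, precision_recall_curve

rng = np.random.default_rng(0)
x_train = rng.normal(0.0, 1.0, size=(100, 8))               # benign embeddings
x_test = np.vstack([rng.normal(0.0, 1.0, size=(80, 8)),     # benign test points
                    rng.normal(3.0, 1.0, size=(20, 8))])    # "attack" points
y_test = np.array([0] * 80 + [1] * 20)                      # attacks are the positive class

mean, std = x_train.mean(axis=0), x_train.std(axis=0)
std[std == 0] = 1e-15                                       # same guard the patch adds
x_train, x_test = (x_train - mean) / std, (x_test - mean) / std

n_neighbors = 10
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(x_train)
train_dist, _ = nbrs.kneighbors(x_train)
mean_distance = train_dist.mean() * n_neighbors / (n_neighbors - 1)  # drop the self-distance
test_dist, _ = nbrs.kneighbors(x_test)
score = test_dist.mean(axis=1) / mean_distance

auc = roc_auc_score(y_test, score)
prec, rec, thr = precision_recall_curve(y_test, score)
f1 = 2 * prec * rec / (prec + rec + 1e-9)
print(f"AUC={auc:.3f}  best F1={f1.max():.3f}")
```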
forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) diff --git a/model/model - Copie.py b/model/model - Copie.py deleted file mode 100644 index 9bd76a77dea36485ce923267908cca624db6ae65..0000000000000000000000000000000000000000 --- a/model/model - Copie.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from utils.poolers import Pooling -from dgl.nn import EdgeGATConv, GlobalAttentionPooling -from torch.nn import GRUCell -import dgl - - - - - - -class GNN_DTDG(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, mlp_layers, number_snapshot): - super(GNN_DTDG, self).__init__() - self.encoder = GNN_RNN(n_dim, e_dim, hidden_dim, out_dim ,n_layers, n_heads, device, number_snapshot) - self.decoder = GNN_RNN(out_dim, e_dim, hidden_dim, n_dim ,n_layers, n_heads, device, number_snapshot) - self.number_snapshot = number_snapshot - self.classifier_layers = nn.ModuleList([ - ]) - - for _ in range(mlp_layers - 1): - self.classifier_layers.extend([ - nn.Linear(out_dim, out_dim).to(device), - nn.ReLU(), - ]) - self.classifier_layers.extend([ - nn.Linear(out_dim, 1).to(device), - nn.Sigmoid() - ]) - - self.pooling_gate_nn = nn.Linear(out_dim , 1) - self.pooling = GlobalAttentionPooling(self.pooling_gate_nn) - self.pooler = Pooling("mean") - self.encoder_to_decoder = nn.Linear( out_dim, out_dim, bias=False) - - def forward(self, g): - encoded = self.encoder(g) - new_g = [] - i= 0 - for G in g: - g_encoded = G.clone() - g_encoded.ndata["attr"] = self.encoder_to_decoder(encoded[i]) - new_g.append(g_encoded) - i+=1 - - - - decoded = self.decoder(new_g) - return decoded[-1] - # x = self.pooler(G, embeddings, [1])[0] - # h_g = x.clone() - # for layer in self.classifier_layers: - # x = layer(x) - - def embed(self, g): - return self.encoder(g)[-1] - - - -class GNN_RNN(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot): - super(GNN_RNN, self).__init__() - self.device = device - self.gnn_layers = nn.ModuleList([EdgeGATConv(in_feats=n_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device)]) - - self.out_dim = out_dim - - for _ in range(n_layers-1): - self.gnn_layers.append( - EdgeGATConv(in_feats=out_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device) - ) - - self.rnn_layers = nn.ModuleList([]) - - for _ in range(number_snapshot): - self.rnn_layers.append( - GRUCell(out_dim, out_dim, device = device) - ) - - self.classifier_layers = nn.ModuleList([ - ]) - - - def forward(self, g): - i = 0 - H_s = [] - for G in g: - - with G.to(self.device).local_scope(): - x = G.ndata["attr"].float() - e = G.edata["attr"].float() - for layer in self.gnn_layers: - r = layer(G, x, e) - x = torch.mean(r,dim=1).to(self.device) - del r - - #if ( i == 0): - # H = self.rnn_layers[i](x, x) - # else: - # H = self.rnn_layers[i](x, H) - - H = x - H_s.append(H) - embeddings = H.clone() - i+=1 - #x = self.pooling(g[0], x)[0] - - - return H_s \ No newline at end of file diff --git a/model/model.py b/model/model.py deleted file mode 100644 index 0389eeae588c10742143ba62a97e9b8f5639eb47..0000000000000000000000000000000000000000 --- a/model/model.py +++ /dev/null @@ -1,157 +0,0 @@ -import torch -import torch.nn as nn -from utils.poolers import Pooling -from .loss_func import sce_loss -from .gat import GAT -from .rnn import RNN_Cells -from utils.utils import create_norm -from functools 
import partial - - -class STGNN_AutoEncoder(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot, activation, feat_drop, negative_slope, residual, norm, pooling, loss_fn="sce", alpha_l=2, use_all_hidden = True): - super(STGNN_AutoEncoder, self).__init__() - - #Initialize the encoder and decoder structure - self.encoder = STGNN(n_dim, e_dim, out_dim, out_dim, n_layers, n_heads, n_heads, number_snapshot, activation, feat_drop, negative_slope, residual, norm, True, use_all_hidden, device) - self.decoder = STGNN(out_dim, e_dim, out_dim, n_dim, 1, n_heads, 1, number_snapshot, activation, feat_drop, negative_slope, residual, norm, False, False, device) - - - # Linear layer for mapping encoder output to decoder input - if (use_all_hidden): - self.encoder_to_decoder = nn.Linear(n_layers * out_dim, out_dim, bias=False) - else: - self.encoder_to_decoder = nn.Linear(out_dim, out_dim, bias=False) - - # Additional components and parameters - self.n_layers = n_layers - self.pooler = Pooling(pooling) - self.number_snapshot = number_snapshot - self.use_all_hidden = use_all_hidden - self.device = device - self.criterion = self.setup_loss_fn(loss_fn, alpha_l) - - def setup_loss_fn(self, loss_fn, alpha_l): - if loss_fn == "sce": - criterion = partial(sce_loss, alpha=alpha_l) - - elif loss_fn == "ce": - criterion = nn.CrossEntropyLoss() - elif loss_fn == "mse": - criterion = nn.MSELoss() - elif loss_fn == "mae": - criterion = nn.L1Loss() - else: - raise NotImplementedError - return criterion - - - def forward(self, g): - - # Encode input graphs - node_features = [] - new_t = [] - for G in g: - new_g = G.clone() - node_features.append(new_g.ndata['attr'].float()) - new_g.edata['attr'] = new_g.edata['attr'].float() - new_t.append(new_g) - final_embedding = self.encoder(new_t, node_features) - encoding = [] - if (self.use_all_hidden): - for i in range(len(g)): - conca = [final_embedding[j][i] for j in range(len(final_embedding))] - encoding.append(torch.cat(conca,dim=1)) - else: - encoding = final_embedding[0] - - node_features = [] - for encoded in encoding: - encoded = self.encoder_to_decoder(encoded) - node_features.append(encoded) - - reconstructed = self.decoder(new_t, node_features) - recon = reconstructed[0][-1] - x_init = g[0].ndata['attr'].float() - loss = self.criterion(recon, x_init) - - return loss - - def embed(self, g): - node_features= [] - for G in g: - node_features.append(G.ndata['attr'].float()) - - return self.encoder.embed(g, node_features) - - -class STGNN(nn.Module): - - def __init__(self, input_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, n_heads_out, n_snapshot, activation, feat_drop, negative_slope, residual, norm, encoding, use_all_hidden, device): - super(STGNN, self).__init__() - - if encoding: - out = out_dim // n_heads - hidden = out_dim // n_heads - else: - hidden = hidden_dim - out = out_dim - - self.gnn = GAT( - n_dim=input_dim, - e_dim=e_dim, - hidden_dim=hidden, - out_dim=out, - n_layers=n_layers, - n_heads=n_heads, - n_heads_out=n_heads_out, - concat_out=True, - activation=activation, - feat_drop=feat_drop, - attn_drop=0.0, - negative_slope=negative_slope, - residual=residual, - norm=create_norm(norm), - encoding=encoding, - ) - self.rnn = RNN_Cells(out_dim, out_dim, n_snapshot, device) - self.use_all_hidden = use_all_hidden - - def forward(self, G, node_features): - - embeddings = [] - for i in range(len(G)): - g = G[i] - if (self.use_all_hidden): - node_embedding, all_hidden = self.gnn(g, node_features[i], 
return_hidden = self.use_all_hidden) - embeddings.append(all_hidden) - n_iter = len(all_hidden) - - else: - embeddings.append(self.gnn(g, node_features[i], return_hidden = self.use_all_hidden)) - n_iter = 1 - - result = [] - for j in range(n_iter): - encoding = [] - - for embedding in embeddings : - if (self.use_all_hidden): - encoding.append(embedding[j]) - else: - encoding.append(embedding) - - result.append(self.rnn(encoding)) - - return result - - - def embed(self, G, node_features): - embeddings = [] - for i in range(len(G)): - g = G[i].clone() - g.edata['attr'] = g.edata['attr'].float() - embedding = self.gnn(g, node_features[i], return_hidden = False) - embeddings.append(embedding) - - return self.rnn(embeddings)[-1] \ No newline at end of file diff --git a/model/rnn.py b/model/rnn.py deleted file mode 100644 index 663634fbdfbba966894f74241faed35876657f5c..0000000000000000000000000000000000000000 --- a/model/rnn.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch.nn import GRUCell -import torch.nn as nn - - -class RNN_Cells(nn.Module): - def __init__(self, input_dim, hidden_dim, n_cells, device) : - super(RNN_Cells, self).__init__() - self.cells = nn.ModuleList() - - for i in range(n_cells): - self.cells.append(GRUCell(input_dim, hidden_dim, device=device)) - - - def forward(self, inputs): - - results = [] - for i in range(len(self.cells)): - if (i == 0): - results.append(self.cells[i](inputs[i], inputs[i])) - else: - results.append(self.cells[i](inputs[i], results[i-1])) - - return results \ No newline at end of file diff --git a/model/train.py b/model/train.py new file mode 100755 index 0000000000000000000000000000000000000000..9ec20ff2f7ec83a695904dff381cb10191d72bfc --- /dev/null +++ b/model/train.py @@ -0,0 +1,28 @@ +import dgl +import numpy as np +from tqdm import tqdm +from utils.loaddata import transform_graph + + +def batch_level_train(model, graphs, train_loader, optimizer, max_epoch, device, n_dim=0, e_dim=0): + epoch_iter = tqdm(range(max_epoch)) + cpt = 1 + for epoch in epoch_iter: + model.train() + loss_list = [] + cpt = 0 + for _, batch in enumerate(train_loader): + cpt +=1 + print(cpt) + batch_g = [transform_graph(graphs[idx][0], n_dim, e_dim).to(device) for idx in batch] + batch_g = dgl.batch(batch_g) + model.train() + loss = model(batch_g) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_list.append(loss.item()) + del batch_g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") + return model diff --git a/model/train_entity.py b/model/train_entity.py deleted file mode 100644 index 58aa15140e6204fbc48d865036ea37eae251858a..0000000000000000000000000000000000000000 --- a/model/train_entity.py +++ /dev/null @@ -1,29 +0,0 @@ -from tqdm import tqdm -from utils.dataloader import load_entity_level_dataset -import torch - - - - -def entity_level_train(model, snapshot, optimizer, max_epoch, device, dataset_name, train_data): - - model = model.to(device) - model.train() - epoch_iter = tqdm(range(max_epoch)) - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in train_data: - g = load_entity_level_dataset(dataset_name, 'train', i, snapshot, device) - model.train() - loss = model(g) - loss /= len(train_data) - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward() - optimizer.step() - - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - - return model \ No newline at end of file diff --git 
a/model/train_graph.py b/model/train_graph.py deleted file mode 100644 index 1ffbd9869fab466b92497c4815db704f9adc1727..0000000000000000000000000000000000000000 --- a/model/train_graph.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -from tqdm import tqdm -import torch -from utils.dataloader import load_graph -import matplotlib.pyplot as plt -from model.eval import batch_level_evaluation - -def batch_level_train(model, train_loader, optimizer, max_epoch, device, n_dim, e_dim, dataset_name, validation=True): - - epoch_iter = tqdm(range(max_epoch)) - model.to(device) # Move model to GPU - n_epoch = 0 - validation_f1 = [] - loss_global = [] - for epoch in epoch_iter: - model.train() - loss_list = [] - for iter, batch in enumerate(train_loader): - batch_g = [load_graph(int(idx), dataset_name, device) for idx in batch] # Move data to GPU - model.train() - g = batch_g[0] - loss = model(g) - optimizer.zero_grad() - loss.backward() - optimizer.step() - loss_list.append(loss.item()) - del batch_g, g - - n_epoch +=1 - - if (validation): - validation_f1.append(batch_level_evaluation(model, model.pooler, device, ['knn'], dataset_name, n_dim, e_dim)[0]) - - loss_global.append(np.mean(loss_list)) - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") - - if (validation): - plt.plot(list(range(n_epoch)), validation_f1, label='Graph 2', marker='o', linestyle='-') - plt.plot(list(range(n_epoch)), loss_global, label='Graph 2', marker='x', linestyle='--') - plt.xlabel('X-axis') - plt.ylabel('Y-axis') - plt.title('Two Graphs on the Same Plot') - - # Add a legend - plt.legend() - - # Display the plot - plt.show() - return model diff --git a/mpi_host_file b/mpi_host_file old mode 100644 new mode 100755 diff --git a/requirement.txt b/requirement.txt old mode 100644 new mode 100755 diff --git a/result/FedAvg-SC2.pt b/result/FedAvg-SC2.pt deleted file mode 100644 index 89411388841175b247e424a74463231d45fe1391..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-SC2.pt and /dev/null differ diff --git a/result/FedAvg-Unicorn-Cadets.pt b/result/FedAvg-Unicorn-Cadets.pt deleted file mode 100644 index 161ba7a56e45a3d1c33bed20b12bf8564b50a8c8..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-Unicorn-Cadets.pt and /dev/null differ diff --git a/result/FedAvg-cadets-e3.pt b/result/FedAvg-cadets-e3.pt deleted file mode 100644 index 4afaa90ee421747fa19d39673f9b63b06eb9804f..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-cadets-e3.pt and /dev/null differ diff --git a/result/FedAvg-clearscope-e3.pt b/result/FedAvg-clearscope-e3.pt deleted file mode 100644 index 136dc3c1eef647a32d2189a8896551aa1b5df162..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-clearscope-e3.pt and /dev/null differ diff --git a/result/FedAvg-streamspot.pt b/result/FedAvg-streamspot.pt deleted file mode 100644 index 82308fe0b86721ef8790eff5354ec0e74388d70e..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-streamspot.pt and /dev/null differ diff --git a/result/FedAvg-theia-e3.pt b/result/FedAvg-theia-e3.pt deleted file mode 100644 index 5a995ad213ddd3883f7c8f000095fd6821410d2b..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-theia-e3.pt and /dev/null differ diff --git a/result/FedAvg-theia.pt b/result/FedAvg-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedAvg-trace-e3.pt 
b/result/FedAvg-trace-e3.pt deleted file mode 100644 index 4dca8e988169b150f8d1692ad751a5bb051e5609..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-trace-e3.pt and /dev/null differ diff --git a/result/FedAvg-wget-long.pt b/result/FedAvg-wget-long.pt deleted file mode 100644 index 9e7504eb16446ab46f327bcc30f13d77b051f75e..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-wget-long.pt and /dev/null differ diff --git a/result/FedAvg-wget.pt b/result/FedAvg-wget.pt deleted file mode 100644 index 51c0ecdfba6589f6d63433368fb10c5810705765..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-wget.pt and /dev/null differ diff --git a/result/FedAvg_Streamspot-streamspot.pt b/result/FedAvg_Streamspot-streamspot.pt old mode 100644 new mode 100755 diff --git a/result/FedOpt-theia.pt b/result/FedOpt-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedOpt_Streamspot.pt b/result/FedOpt_Streamspot.pt old mode 100644 new mode 100755 diff --git a/result/FedProx-theia.pt b/result/FedProx-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedProx_Streamspot.pt b/result/FedProx_Streamspot.pt old mode 100644 new mode 100755 diff --git a/save_results/distance_save_cadets_FedAvg.pkl b/save_results/distance_save_cadets_FedAvg.pkl new file mode 100644 index 0000000000000000000000000000000000000000..870244cd32749e587a2b411dc1c13dee18d08781 Binary files /dev/null and b/save_results/distance_save_cadets_FedAvg.pkl differ diff --git a/save_results/distance_save_cadets_FedOpt.pkl b/save_results/distance_save_cadets_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9bdff9888dbf1ee8f9fb45f255b54502b4b56d6c Binary files /dev/null and b/save_results/distance_save_cadets_FedOpt.pkl differ diff --git a/save_results/distance_save_cadets_FedProx.pkl b/save_results/distance_save_cadets_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8680ed4c84080861d938bc5ea620d9d8589fc39a Binary files /dev/null and b/save_results/distance_save_cadets_FedProx.pkl differ diff --git a/save_results/distance_save_theia_FedAvg.pkl b/save_results/distance_save_theia_FedAvg.pkl new file mode 100755 index 0000000000000000000000000000000000000000..d885a972d533697733d1c953742df203a5c183a9 Binary files /dev/null and b/save_results/distance_save_theia_FedAvg.pkl differ diff --git a/save_results/distance_save_theia_FedOpt.pkl b/save_results/distance_save_theia_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3e84da9be1a793612b77573f16b80f4bba3f530b Binary files /dev/null and b/save_results/distance_save_theia_FedOpt.pkl differ diff --git a/save_results/distance_save_theia_FedProx.pkl b/save_results/distance_save_theia_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..80e772d9f66424b552809cec77dae53a291f83e4 Binary files /dev/null and b/save_results/distance_save_theia_FedProx.pkl differ diff --git a/save_results/distance_save_trace_FedAvg.pkl b/save_results/distance_save_trace_FedAvg.pkl new file mode 100644 index 0000000000000000000000000000000000000000..12baf82b2236f724c1c9f65a3adb19838a693c4b Binary files /dev/null and b/save_results/distance_save_trace_FedAvg.pkl differ diff --git a/save_results/distance_save_trace_FedOpt.pkl b/save_results/distance_save_trace_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6d90af670837a994ba274a3a3a2d09ea5b941d31 Binary files /dev/null and b/save_results/distance_save_trace_FedOpt.pkl 
differ diff --git a/save_results/distance_save_trace_FedProx.pkl b/save_results/distance_save_trace_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e88de86f4bcb805e0475a41d4de3bbfde5d7353e Binary files /dev/null and b/save_results/distance_save_trace_FedProx.pkl differ diff --git a/test.py b/test.py new file mode 100755 index 0000000000000000000000000000000000000000..c1f6ebe728019b94bfcd9513907c28a9d39eaa81 --- /dev/null +++ b/test.py @@ -0,0 +1,15 @@ +from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split + + +( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) = load_partition_data(None, 3, "streamspot") diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..968181a1cfaa6de944115c62b2b187376c734502 --- /dev/null +++ b/train.py @@ -0,0 +1,90 @@ +import os +import random +import torch +import warnings +from tqdm import tqdm +from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph +from model.autoencoder import build_model +from torch.utils.data.sampler import SubsetRandomSampler +from dgl.dataloading import GraphDataLoader +import dgl +from model.train import batch_level_train +from utils.utils import set_random_seed, create_optimizer +from utils.config import build_args +warnings.filterwarnings('ignore') + + +def extract_dataloaders(entries, batch_size): + random.shuffle(entries) + train_idx = torch.arange(len(entries)) + train_sampler = SubsetRandomSampler(train_idx) + train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler) + return train_loader + + +def main(main_args): + device = "cpu" + dataset_name = "trace" + if dataset_name == 'streamspot': + main_args.num_hidden = 256 + main_args.max_epoch = 5 + main_args.num_layers = 4 + elif dataset_name == 'wget': + main_args.num_hidden = 256 + main_args.max_epoch = 2 + main_args.num_layers = 4 + else: + main_args["num_hidden"] = 64 + main_args["max_epoch"] = 50 + main_args["num_layers"] = 3 + set_random_seed(0) + + if dataset_name == 'streamspot' or dataset_name == 'wget': + if dataset_name == 'streamspot': + batch_size = 12 + else: + batch_size = 1 + dataset = load_batch_level_dataset(dataset_name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + graphs = dataset['dataset'] + train_index = dataset['train_index'] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + model = build_model(main_args) + model = model.to(device) + optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay) + model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)), + optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim) + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + else: + metadata = load_metadata(dataset_name) + main_args["n_dim"] = metadata['node_feature_dim'] + main_args["e_dim"] = metadata['edge_feature_dim'] + model = build_model(main_args) + model = model.to(device) + model.train() + optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"]) + epoch_iter = tqdm(range(main_args["max_epoch"])) + n_train = metadata['n_train'] + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in range(n_train): + g = 
load_entity_level_dataset(dataset_name, 'train', i).to(device) + model.train() + loss = model(g) + loss /= n_train + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward() + optimizer.step() + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name)) + + return + + +if __name__ == '__main__': + args = build_args() + main(args) diff --git a/trainer/__pycache__/magic_aggregator.cpython-310.pyc b/trainer/__pycache__/magic_aggregator.cpython-310.pyc deleted file mode 100644 index 173a82a3b45a20cb705440e95a4e6d09bf98930b..0000000000000000000000000000000000000000 Binary files a/trainer/__pycache__/magic_aggregator.cpython-310.pyc and /dev/null differ diff --git a/trainer/__pycache__/magic_aggregator.cpython-311.pyc b/trainer/__pycache__/magic_aggregator.cpython-311.pyc index e4c63a27702b466871bad2b0f18230d3620cd143..d9f8b782ee6ef9465d777c9e882fb180828f07ea 100644 Binary files a/trainer/__pycache__/magic_aggregator.cpython-311.pyc and b/trainer/__pycache__/magic_aggregator.cpython-311.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-310.pyc b/trainer/__pycache__/magic_trainer.cpython-310.pyc deleted file mode 100644 index 5aa1877df83f3deb6862f39988826b001e5273f8..0000000000000000000000000000000000000000 Binary files a/trainer/__pycache__/magic_trainer.cpython-310.pyc and /dev/null differ diff --git a/trainer/__pycache__/magic_trainer.cpython-311.pyc b/trainer/__pycache__/magic_trainer.cpython-311.pyc index a42939be3e8c7c9ffc0d3985c7b46efe4ab53d0f..a9ecabf7f8cc81c7ccd9f920ba01e5e3c9436ee9 100644 Binary files a/trainer/__pycache__/magic_trainer.cpython-311.pyc and b/trainer/__pycache__/magic_trainer.cpython-311.pyc differ diff --git a/trainer/__pycache__/single_trainer.cpython-311.pyc b/trainer/__pycache__/single_trainer.cpython-311.pyc index 4d343331aa188a304c811aafdd0d1814cc17efe9..5bfbd974b57fc9e9abe1ee459c27d16e5465c94d 100644 Binary files a/trainer/__pycache__/single_trainer.cpython-311.pyc and b/trainer/__pycache__/single_trainer.cpython-311.pyc differ diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py old mode 100644 new mode 100755 index a5d09149bf53e9180f58802853aa4e6eca33c438..a9d7f70c78d64347b655494de696050e82d0fa6f --- a/trainer/magic_aggregator.py +++ b/trainer/magic_aggregator.py @@ -6,15 +6,11 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from utils.config import build_args from fedml.core import ServerAggregator -from model.train_graph import batch_level_train -from model.train_entity import entity_level_train from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn - - -from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling -from utils.config import build_args -from utils.dataloader import load_data, load_entity_level_dataset, load_metadata +# Trainer for MoleculeNet. The evaluation metric is ROC-AUC +from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset +from utils.loaddata import load_batch_level_dataset class MagicWgetAggregator(ServerAggregator): def __init__(self, model, args, name): @@ -66,30 +62,25 @@ class MagicWgetAggregator(ServerAggregator): logging.info("Models match perfectly! 
:)") def _test(self, test_data, device, args): - main_args = build_args() - dataset_name = self.name - nsnapshot = args.snapshot - if (self.name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - logging.info("----------test--------") - set_random_seed(0) - dataset = load_data(dataset_name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat + args = build_args() + if (self.name == 'wget' or self.name == 'streamspot'): + logging.info("----------test--------") + model = self.model model.eval() model.to(device) - model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - model = model.to(device) - pooler = Pooling(main_args.pooling) - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, - main_args.e_dim) + pooler = Pooling(args["pooling"]) + dataset = load_batch_level_dataset(self.name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) else: + torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name)) metadata = load_metadata(self.name) - main_args.n_dim = metadata['node_feature_dim'] - main_args.e_dim = metadata['edge_feature_dim'] + args["n_dim"] = metadata['node_feature_dim'] + args["e_dim"] = metadata['edge_feature_dim'] model = self.model.to(device) model.eval() malicious, _ = metadata['malicious'] @@ -99,17 +90,17 @@ class MagicWgetAggregator(ServerAggregator): with torch.no_grad(): x_train = [] for i in range(n_train): - g = load_entity_level_dataset(dataset_name, 'train', i, nsnapshot, device) + g = load_entity_level_dataset(self.name, 'train', i).to(device) x_train.append(model.embed(g).cpu().detach().numpy()) del g x_train = np.concatenate(x_train, axis=0) skip_benign = 0 x_test = [] for i in range(n_test): - g = load_entity_level_dataset(self.name, 'test', i, nsnapshot, device) + g = load_entity_level_dataset(self.name, 'test', i).to(device) # Exclude training samples from the test set if i != n_test - 1: - skip_benign += g[0].number_of_nodes() + skip_benign += g.number_of_nodes() x_test.append(model.embed(g).cpu().detach().numpy()) del g x_test = np.concatenate(x_test, axis=0) diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py old mode 100644 new mode 100755 index 02c428c2459ec140dee5be82bbab64fcf0532c33..99c9a5b49e31a631f30053db0c3cf0623569b2eb --- a/trainer/magic_trainer.py +++ b/trainer/magic_trainer.py @@ -10,18 +10,18 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from fedml.core import ClientTrainer +from model.autoencoder import build_model from torch.utils.data.sampler import SubsetRandomSampler from dgl.dataloading import GraphDataLoader - -from model.train_graph import batch_level_train -from model.train_entity import entity_level_train -from model.eval import 
batch_level_evaluation, evaluate_entity_level_using_knn - - +from model.train import batch_level_train from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling from utils.config import build_args -from utils.dataloader import load_data, load_entity_level_dataset, load_metadata +from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata +from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn +from model.autoencoder import build_model +from data.data_loader import transform_data +from utils.loaddata import load_batch_level_dataset # Trainer for MoleculeNet. The evaluation metric is ROC-AUC def extract_dataloaders(entries, batch_size): @@ -47,47 +47,77 @@ class MagicTrainer(ClientTrainer): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): - main_args = build_args() - dataset_name = self.name - input('start') - if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - if (dataset_name == 'wget' or dataset_name == 'streamspot' or dataset_name == 'Unicorn-Cadets' or dataset_name == 'clearscope-e3'): - batch_size = 1 - main_args.max_epoch = 6 - - elif (dataset_name == 'SC2' or dataset_name == 'wget-long'): - batch_size = 1 - main_args.max_epoch = 1 - + test_data = None + args = build_args() + if (self.name == "wget"): + args["num_hidden"] = 256 + args["max_epoch"] = 2 + args["num_layers"] = 4 + batch_size = 1 + elif (self.name == "streamspot"): + args["num_hidden"] = 256 + args["max_epoch"] = 5 + args["num_layers"] = 4 + batch_size = 12 + else: + args["num_hidden"] = 64 + args["max_epoch"] = 50 + args["num_layers"] = 3 - dataset = load_data(dataset_name) + max_test_score = 0 + best_model_params = {} + if (self.name == 'wget' or self.name == 'streamspot'): + + dataset = load_batch_level_dataset(self.name) + data = transform_data(train_data) n_node_feat = dataset['n_feat'] n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - validation_index = dataset['validation_index'] - label = dataset["labels"] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - main_args.optimizer = "adamw" - set_random_seed(0) - #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) - train_loader = extract_dataloaders(train_data[0], batch_size) - self.model = batch_level_train(self.model, train_loader, optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim, dataset_name, validation= False) + graphs = data['dataset'] + train_index = data['train_index'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + #self.model = build_model(args) + self.model = self.model.to(device) + optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) + self.model = batch_level_train(self.model, graphs, (extract_dataloaders(train_index, batch_size)), + optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat) + test_score, _ = self.test(test_data, device, 
args) else: - main_args.max_epoch = 50 - nsnapshot = args.snapshot - main_args.optimizer = "adam" - set_random_seed(0) - #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) - #train_loader = extract_dataloaders(train_index, batch_size) - self.model = entity_level_train(self.model, nsnapshot, optimizer, main_args.max_epoch, device, dataset_name, train_data[0]) - - self.max = 0 - best_model_params = { + + metadata = load_metadata(self.name) + args["n_dim"] = metadata['node_feature_dim'] + args["e_dim"] = metadata['edge_feature_dim'] + self.model = self.model.to(device) + self.model.train() + optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) + epoch_iter = tqdm(range(args["max_epoch"])) + n_train = len(train_data[0]) + input("start?") + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in range(n_train): + g = train_data[0][i] + self.model.train() + loss = self.model(g) + loss /= n_train + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward(retain_graph=True) + optimizer.step() + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + if (self.name == 'wget' or self.name == 'streamspot'): + test_score, _ = self.test(test_data, device, args) + if test_score > self.max: + self.max = test_score + best_model_params = { + k: v.cpu() for k, v in self.model.state_dict().items() + } + else: + self.max = 0 + best_model_params = { k: v.cpu() for k, v in self.model.state_dict().items() - } + } @@ -96,31 +126,17 @@ class MagicTrainer(ClientTrainer): def test(self, test_data, device, args): if (self.name == 'wget' or self.name == 'streamspot'): logging.info("----------test--------") - main_args = build_args() - dataset_name = self.name - if dataset_name in ['streamspot', 'wget', 'SC2']: - main_args.num_hidden = 256 - main_args.num_layers = 4 - main_args.max_epoch = 50 - else: - main_args.num_hidden = 64 - main_args.num_layers = 3 - set_random_seed(0) - dataset = load_data(dataset_name, 1, 0.6, 0.2) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat + args = build_args() model = self.model model.eval() model.to(device) - model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - model = model.to(device) - pooler = Pooling(main_args.pooling) - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, - main_args.e_dim) - + pooler = Pooling(args["pooling"]) + dataset = load_batch_level_dataset(self.name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) else: metadata = load_metadata(self.name) args["n_dim"] = metadata['node_feature_dim'] diff --git a/trainer/single_trainer.py 
b/trainer/single_trainer.py old mode 100644 new mode 100755 diff --git a/trainer/utils/__pycache__/config.cpython-311.pyc b/trainer/utils/__pycache__/config.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..233d940b3ec7d7070e2ee889729ac7592675d3bf Binary files /dev/null and b/trainer/utils/__pycache__/config.cpython-311.pyc differ diff --git a/trainer/utils/config.py b/trainer/utils/config.py new file mode 100755 index 0000000000000000000000000000000000000000..bb9c1ef28d385b80c7673027b4f69b31c36b1175 --- /dev/null +++ b/trainer/utils/config.py @@ -0,0 +1,20 @@ +import argparse + + +def build_args(): + parser = argparse.ArgumentParser(description="MAGIC") + parser.add_argument("--dataset", type=str, default="wget") + parser.add_argument("--device", type=int, default=-1) + parser.add_argument("--lr", type=float, default=0.001, + help="learning rate") + parser.add_argument("--weight_decay", type=float, default=5e-4, + help="weight decay") + parser.add_argument("--negative_slope", type=float, default=0.2, + help="the negative slope of leaky relu for GAT") + parser.add_argument("--mask_rate", type=float, default=0.5) + parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") + parser.add_argument("--optimizer", type=str, default="adam") + parser.add_argument("--loss_fn", type=str, default='sce') + parser.add_argument("--pooling", type=str, default="mean") + args = parser.parse_args() + return args diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py new file mode 100755 index 0000000000000000000000000000000000000000..41e7dfc03ee39adcc1f5729d59aa21124d981fff --- /dev/null +++ b/trainer/utils/loaddata.py @@ -0,0 +1,197 @@ +import pickle as pkl +import time +import torch.nn.functional as F +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os + + +class StreamspotDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(StreamspotDataset, self).__init__(name=name) + if name == 'streamspot': + path = './data/streamspot' + num_graphs = 600 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 300 <= idx <= 399: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + path = './data/wget/final' + num_graphs = 150 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 0 <= idx <= 24: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +def load_rawdata(name): + if name == 'streamspot': + path = './data/streamspot' + if os.path.exists(path + '/graphs.pkl'): + 
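Side note on the dataset loaders introduced above (StreamspotDataset, WgetDataset, load_rawdata): they all follow the same pattern of turning node-link JSON into DGL graphs with integer `type` attributes on nodes and edges, and caching the parsed result as a pickle so `dgl.from_networkx` does not have to re-run over every JSON file on each start-up. A minimal standalone sketch of that pattern follows; the file path and cache name are illustrative placeholders, not guaranteed paths.

```
import json
import os
import pickle as pkl

import dgl
import networkx as nx


def load_one_graph(json_path):
    # node-link JSON (as written by nx.node_link_data) -> DGLGraph,
    # keeping the integer 'type' attribute on nodes and edges.
    with open(json_path, 'r', encoding='utf-8') as f:
        nx_g = nx.node_link_graph(json.load(f))
    return dgl.from_networkx(nx_g, node_attrs=['type'], edge_attrs=['type'])


def load_cached(json_path, cache_path='./graphs.pkl'):
    # Same caching idea as load_rawdata: parse once, reuse the pickle afterwards.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pkl.load(f)
    g = load_one_graph(json_path)
    with open(cache_path, 'wb') as f:
        pkl.dump(g, f)
    return g


if __name__ == '__main__':
    # Illustrative path following the wget layout above; adjust to your data.
    print(load_cached('./data/wget/final/0.json'))
```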
print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = StreamspotDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + elif name == 'wget': + path = './data/wget' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = WgetDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + else: + raise NotImplementedError + return raw_data + + +def load_batch_level_dataset(dataset_name): + dataset = load_rawdata(dataset_name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +def transform_graph(g, node_feature_dim, edge_feature_dim): + new_g = g.clone() + new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() + new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() + return new_g + + +def preload_entity_level_dataset(path): + path = './data/' + path + if os.path.exists(path + '/metadata.json'): + pass + else: + print('transforming') + train_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] + print('transforming') + test_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + node_feature_dim = 0 + for g in train_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + for g in test_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + node_feature_dim += 1 + edge_feature_dim = 0 + for g in train_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + for g in test_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + edge_feature_dim += 1 + result_test_gs = [] + for g in test_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_test_gs.append(g) + result_train_gs = [] + for g in train_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_train_gs.append(g) + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': len(result_train_gs), + 'n_test': len(result_test_gs) + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + for i, g in enumerate(result_train_gs): + with open(path + '/train{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + for i, g in enumerate(result_test_gs): + with open(path + '/test{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + + +def load_metadata(path): + preload_entity_level_dataset(path) + with open('./data/' + path + 
'/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(path, t, n): + preload_entity_level_dataset(path) + with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: + data = pkl.load(f) + return data diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py new file mode 100755 index 0000000000000000000000000000000000000000..4b9dc4aee1eac42f578952596a5e400adbd1e391 --- /dev/null +++ b/trainer/utils/poolers.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +class Pooling(nn.Module): + def __init__(self, pooler): + super(Pooling, self).__init__() + self.pooler = pooler + + def forward(self, graph, feat, t=None): + feat = feat + # Implement node type-specific pooling + with graph.local_scope(): + if t is None: + if self.pooler == 'mean': + return feat.mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat.sum(0, keepdim=True) + elif self.pooler == 'max': + return feat.max(0, keepdim=True) + else: + raise NotImplementedError + elif isinstance(t, int): + mask = (graph.ndata['type'] == t) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError + else: + mask = (graph.ndata['type'] == t[0]) + for i in range(1, len(t)): + mask |= (graph.ndata['type'] == t[i]) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError diff --git a/trainer/utils/streamspot_parser.py b/trainer/utils/streamspot_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..04438eb0f6336ae2bad2015ad6caf4919f48f9a3 --- /dev/null +++ b/trainer/utils/streamspot_parser.py @@ -0,0 +1,53 @@ +import networkx as nx +from tqdm import tqdm +import json +raw_path = '../data/streamspot/' + +NUM_GRAPHS = 600 +node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] +edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] +node_type_set = set(node_type_dict) +edge_type_set = set(edge_type_dict) + +count_graph = 0 +with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: + lines = f.readlines() + g = nx.DiGraph() + node_map = {} + count_node = 0 + for line in tqdm(lines): + src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') + graph_id = int(graph_id) + if src_type not in node_type_set or dst_type not in node_type_set: + continue + if etype not in edge_type_set: + continue + if graph_id != count_graph: + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) + assert graph_id == count_graph + g = nx.DiGraph() + count_node = 0 + if src not in node_map: + node_map[src] = count_node + g.add_node(count_node, type=src_type) + count_node += 1 + if dst not in node_map: + node_map[dst] = count_node + g.add_node(count_node, type=dst_type) + count_node += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=etype) + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = 
node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7 --- /dev/null +++ b/trainer/utils/trace_parser.py @@ -0,0 +1,288 @@ +import argparse +import json +import os +import random +import re + +from tqdm import tqdm +import networkx as nx +import pickle as pkl + + +node_type_dict = {} +edge_type_dict = {} +node_type_cnt = 0 +edge_type_cnt = 0 + +metadata = { + 'trace':{ + 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], + 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] + }, + 'theia':{ + 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], + 'test': ['ta1-theia-e3-official-6r.json.8'] + }, + 'cadets':{ + 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], + 'test': ['ta1-cadets-e3-official-2.json'] + } +} + + +pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') +pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_type = re.compile(r'type\":\"(.*?)\"') +pattern_time = re.compile(r'timestampNanos\":(.*?),') +pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') +pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') +pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') + +adversarial = 'None' + + +def read_single_graph(dataset, malicious, path, test=False): + global node_type_cnt, edge_type_cnt + g = nx.DiGraph() + print('converting {} ...'.format(path)) + path = '../data/{}/'.format(dataset) + path + '.txt' + f = open(path, 'r') + lines = [] + for l in f.readlines(): + split_line = l.split('\t') + src, src_type, dst, dst_type, edge_type, ts = split_line + ts = int(ts) + if not test: + if src in malicious or dst in malicious: + if src in malicious and src_type != 'MemoryObject': + continue + if dst in malicious and dst_type != 'MemoryObject': + continue + + if src_type not in node_type_dict: + node_type_dict[src_type] = node_type_cnt + node_type_cnt += 1 + if dst_type not in node_type_dict: + node_type_dict[dst_type] = node_type_cnt + node_type_cnt += 1 + if edge_type not in edge_type_dict: + edge_type_dict[edge_type] = edge_type_cnt + edge_type_cnt += 1 + if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: + lines.append([dst, src, dst_type, src_type, edge_type, ts]) + else: + lines.append([src, dst, src_type, dst_type, edge_type, ts]) + lines.sort(key=lambda l: l[5]) + + node_map = {} + node_type_map = {} + node_cnt = 0 + node_list = [] + for l in lines: + src, dst, src_type, dst_type, edge_type = l[:5] + src_type_id = node_type_dict[src_type] + dst_type_id = node_type_dict[dst_type] + edge_type_id = edge_type_dict[edge_type] + if 
src not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if src in malicious: + src_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if src not in malicious: + i = random.randint(1, 20) + if i == 1: + src_type_id = node_type_dict['NetFlowObject'] + node_map[src] = node_cnt + g.add_node(node_cnt, type=src_type_id) + node_list.append(src) + node_type_map[src] = src_type + node_cnt += 1 + if dst not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if dst in malicious: + dst_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if dst not in malicious: + i = random.randint(1, 20) + if i == 1: + dst_type_id = node_type_dict['NetFlowObject'] + node_map[dst] = node_cnt + g.add_node(node_cnt, type=dst_type_id) + node_type_map[dst] = dst_type + node_list.append(dst) + node_cnt += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=edge_type_id) + if (adversarial == 'MSE' or adversarial == 'MCE') and test: + for i, node in enumerate(node_list): + if node in malicious: + while True: + another_node = random.choice(node_list) + if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) + print(node, another_node) + break + if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) + print(node, another_node) + break + if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) + print(node, another_node) + break + if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': + break + return node_map, g + + +def preprocess_dataset(dataset): + id_nodetype_map = {} + id_nodename_map = {} + for file in os.listdir('../data/{}/'.format(dataset)): + if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: + print('reading {} ...'.format(file)) + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue + if len(pattern_uuid.findall(line)) == 0: print(line) + uuid = pattern_uuid.findall(line)[0] + subject_type = pattern_type.findall(line) + + if len(subject_type) < 1: + if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: + subject_type = 'MemoryObject' + if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: + subject_type = 'NetFlowObject' + if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: + subject_type = 'UnnamedPipeObject' + else: + subject_type = subject_type[0] + + if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: + continue + id_nodetype_map[uuid] = subject_type + if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: + id_nodename_map[uuid] = 
pattern_file_name.findall(line)[0] + elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_process_name.findall(line)[0] + elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] + for key in metadata[dataset]: + for file in metadata[dataset][key]: + if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): + continue + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') + print('processing {} ...'.format(file)) + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line: + edgeType = pattern_type.findall(line)[0] + timestamp = pattern_time.findall(line)[0] + srcId = pattern_src.findall(line) + + if len(srcId) == 0: continue + srcId = srcId[0] + if not srcId in id_nodetype_map: + continue + srcType = id_nodetype_map[srcId] + dstId1 = pattern_dst1.findall(line) + if len(dstId1) > 0 and dstId1[0] != 'null': + dstId1 = dstId1[0] + if not dstId1 in id_nodetype_map: + continue + dstType1 = id_nodetype_map[dstId1] + this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( + dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge1) + + dstId2 = pattern_dst2.findall(line) + if len(dstId2) > 0 and dstId2[0] != 'null': + dstId2 = dstId2[0] + if not dstId2 in id_nodetype_map.keys(): + continue + dstType2 = id_nodetype_map[dstId2] + this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( + dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge2) + fw.close() + f.close() + if len(id_nodename_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') + json.dump(id_nodename_map, fw) + if len(id_nodetype_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') + json.dump(id_nodetype_map, fw) + + +def read_graphs(dataset): + malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) + f = open(malicious_entities, 'r') + malicious_entities = set() + for l in f.readlines(): + malicious_entities.add(l.lstrip().rstrip()) + + preprocess_dataset(dataset) + train_gs = [] + for file in metadata[dataset]['train']: + _, train_g = read_single_graph(dataset, malicious_entities, file, False) + train_gs.append(train_g) + test_gs = [] + test_node_map = {} + count_node = 0 + for file in metadata[dataset]['test']: + node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) + assert len(node_map) == test_g.number_of_nodes() + test_gs.append(test_g) + for key in node_map: + if key not in test_node_map: + test_node_map[key] = node_map[key] + count_node + count_node += test_g.number_of_nodes() + + if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): + with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodename_map = json.load(f) + with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodetype_map = json.load(f) + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 
'UnnamedPipeObject': + final_malicious_entities.append(test_node_map[e]) + if e in id_nodename_map: + malicious_names.append(id_nodename_map[e]) + f.write('{}\t{}\n'.format(e, id_nodename_map[e])) + else: + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + else: + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map: + final_malicious_entities.append(test_node_map[e]) + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + + pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CDM Parser') + parser.add_argument("--dataset", type=str, default="trace") + args = parser.parse_args() + if args.dataset not in ['trace', 'theia', 'cadets']: + raise NotImplementedError + read_graphs(args.dataset) + diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..bcf9a481407696d73f30fe1dde279154a05702b3 --- /dev/null +++ b/trainer/utils/utils.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +from functools import partial +import numpy as np +import random +import torch.optim as optim + + +def create_optimizer(opt, model, lr, weight_decay): + opt_lower = opt.lower() + parameters = model.parameters() + opt_args = dict(lr=lr, weight_decay=weight_decay) + optimizer = None + opt_split = opt_lower.split("_") + opt_lower = opt_split[-1] + if opt_lower == "adam": + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == "adamw": + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == "adadelta": + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == "radam": + optimizer = optim.RAdam(parameters, **opt_args) + elif opt_lower == "sgd": + opt_args["momentum"] = 0.9 + return optim.SGD(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + return optimizer + + +def random_shuffle(x, y): + idx = list(range(len(x))) + random.shuffle(idx) + return x[idx], y[idx] + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.determinstic = True + + +def create_activation(name): + if name == "relu": + return nn.ReLU() + elif name == "gelu": + return nn.GELU() + elif name == "prelu": + return nn.PReLU() + elif name is None: + return nn.Identity() + elif name == "elu": + return nn.ELU() + else: + raise NotImplementedError(f"{name} is not implemented.") + + +def create_norm(name): + if name == "layernorm": + return nn.LayerNorm + elif name == "batchnorm": + return nn.BatchNorm1d + elif name == "graphnorm": + return partial(NormLayer, norm_type="groupnorm") + else: + return None + + +class NormLayer(nn.Module): + def __init__(self, hidden_dim, norm_type): + super().__init__() + if norm_type == "batchnorm": + self.norm = nn.BatchNorm1d(hidden_dim) + elif norm_type == "layernorm": + self.norm = nn.LayerNorm(hidden_dim) + elif norm_type == "graphnorm": + self.norm = norm_type + self.weight = nn.Parameter(torch.ones(hidden_dim)) + self.bias = 
nn.Parameter(torch.zeros(hidden_dim)) + + self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) + else: + raise NotImplementedError + + def forward(self, graph, x): + tensor = x + if self.norm is not None and type(self.norm) != str: + return self.norm(tensor) + elif self.norm is None: + return tensor + + batch_list = graph.batch_num_nodes + batch_size = len(batch_list) + batch_list = torch.Tensor(batch_list).long().to(tensor.device) + batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) + batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) + mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) + mean = mean.scatter_add_(0, batch_index, tensor) + mean = (mean.T / batch_list).T + mean = mean.repeat_interleave(batch_list, dim=0) + + sub = tensor - mean * self.mean_scale + + std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) + std = std.scatter_add_(0, batch_index, sub.pow(2)) + std = ((std.T / batch_list).T + 1e-6).sqrt() + std = std.repeat_interleave(batch_list, dim=0) + return self.weight * sub / std + self.bias diff --git a/trainer/utils/wget_parser.py b/trainer/utils/wget_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055 --- /dev/null +++ b/trainer/utils/wget_parser.py @@ -0,0 +1,820 @@ +import argparse +import json +import os + +import xxhash +import tqdm +import logging +import networkx as nx +from tqdm import tqdm +import time +import datetime + + +valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] +CONSOLE_ARGUMENTS = None + + +def hashgen(l): + """Generate a single hash value from a list. @l is a list of + string values, which can be properties of a node/edge. This + function returns a single hashed integer value.""" + hasher = xxhash.xxh64() + for e in l: + hasher.update(e) + return hasher.intdigest() + + +def parse_nodes(json_string, node_map): + """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). + Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, + which is assigned by CamFlow to uniquely identify a node object, to a hashed + value (in str) which represents the 'type' of the node. """ + json_object = None + try: + # use "ignore" if non-decodeable exists in the @json_string + json_object = json.loads(json_string) + except Exception as e: + print("Exception ({}) occurred when parsing a node in JSON:".format(e)) + print(json_string) + exit(1) + if "activity" in json_object: + activity = json_object["activity"] + for uid in activity: + if not uid in node_map: # only parse unseen nodes + if "prov:type" not in activity[uid]: + # a node must have a type. + # record this issue if logging is turned on + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = activity[uid]["prov:type"] + + if "entity" in json_object: + entity = json_object["entity"] + for uid in entity: + if not uid in node_map: + if "prov:type" not in entity[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = entity[uid]["prov:type"] + + +def parse_all_nodes(filename, node_map): + """Parse all nodes in CamFlow data. @filename is the file path of + the CamFlow data to parse. @node_map contains the mappings of all + CamFlow nodes to their hashed attributes. 
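The `hashgen` helper defined above is what the edge-writing pass further down uses (via `hashgen([srcUUID])`) to turn CamFlow UUID strings into compact integer node ids when `noencode` is not set. A tiny usage sketch, assuming only the `xxhash` package already imported in this file; the UUID value is made up for illustration.

```
import xxhash


def hashgen(values):
    # Fold a list of strings into a single 64-bit integer, as in wget_parser.
    hasher = xxhash.xxh64()
    for v in values:
        hasher.update(v)
    return hasher.intdigest()


# Made-up UUID purely for illustration.
print(hashgen(['8c2b8c3f-1111-4b2e-9c1a-2f6d00000001']))
```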
""" + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(filename, 'r') as f: + # each line in CamFlow data could contain multiple + # provenance nodes, we call @parse_nodes routine. + for line in f: + pb.update() # for progress tracking + parse_nodes(line, node_map) + f.close() + pb.close() + + +def parse_all_edges(inputfile, outputfile, node_map, noencode): + """Parse all edges (including their timestamp) from CamFlow data file @inputfile + to an @outputfile. Before this function is called, parse_all_nodes should be called + to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, + we do not hash the nodes' original UUIDs generated by CamFlow to integers. This + function returns the total number of valid edge parsed from CamFlow dataset. + + The output edgelist has the following format for each line, if -s is not set: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> + If -s is set, each line would look like: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" + total_edges = 0 + smallest_timestamp = None + # scan through the entire file to find the smallest timestamp from all the edges. + # this step is only needed if we need to add some statistical information. + if CONSOLE_ARGUMENTS.stats: + description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + continue + if "cf:date" not in used[uid]: + continue + if "prov:entity" not in used[uid]: + continue + if "prov:activity" not in used[uid]: + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + continue + if "cf:date" not in wasGeneratedBy[uid]: + continue + if "prov:entity" not in wasGeneratedBy[uid]: + continue + if "prov:activity" not in wasGeneratedBy[uid]: + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + continue + if "cf:date" not in wasInformedBy[uid]: + continue + if "prov:informant" not in wasInformedBy[uid]: + continue + if "prov:informed" not in wasInformedBy[uid]: + 
continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + continue + if "cf:date" not in wasDerivedFrom[uid]: + continue + if "prov:usedEntity" not in wasDerivedFrom[uid]: + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + continue + if "cf:date" not in wasAssociatedWith[uid]: + continue + if "prov:agent" not in wasAssociatedWith[uid]: + continue + if "prov:activity" not in wasAssociatedWith[uid]: + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + f.close() + pb.close() + + # we will go through the CamFlow data (again) and output edgelist to a file + output = open(outputfile, "w+") + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + # an edge must have a type; if not, + # we will have to skip the edge. Log + # this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without type: {}".format(uid)) + continue + else: + edgetype = "used" + # cf:id is used as logical timestamp to order edges + if "cf:id" not in used[uid]: + # an edge must have a logical timestamp; + # if not we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = used[uid]["cf:id"] + if "prov:entity" not in used[uid]: + # an edge's source node must exist; + # if not, we will have to skip the + # edge. Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) + continue + if "prov:activity" not in used[uid]: + # an edge's destination node must exist; + # if not, we will have to skip the edge. 
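To make the field checks in this pass concrete, here is a sketch of a minimal `used` relation that would survive every guard above, together with the edgelist line it yields when neither stats nor jiffies output is enabled and `noencode` is true. All values are invented for illustration; with encoding on, the two UUID columns would instead be `hashgen([uuid])` integers.

```
# Made-up "used" relation carrying every field the parser checks.
node_map = {"AAAA-1111": "task", "BBBB-2222": "file"}
used_record = {
    "prov:type": "used",
    "cf:id": 42,                        # logical timestamp, orders the edges
    "prov:entity": "BBBB-2222",         # source UUID, must appear in node_map
    "prov:activity": "AAAA-1111",       # destination UUID, must appear in node_map
    "cf:date": "2024:01:01T12:30:00",   # parsed with "%Y:%m:%dT%H:%M:%S"
    "cf:jiffies": "4294907296",         # only written out when jiffies is enabled
}

src, dst = used_record["prov:entity"], used_record["prov:activity"]
line = "{}\t{}\t{}:{}:{}:{}\n".format(
    src, dst, node_map[src], node_map[dst], "used", used_record["cf:id"])
print(line, end="")  # BBBB-2222	AAAA-1111	file:task:used:42
```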
+ # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], + uid)) + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + # both source and destination node must + # exist in @node_map; if not, we will + # have to skip the edge. Log this issue + # if verbose is set. + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in used[uid]: + # an edge must have a timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without timestamp: {}".format(uid)) + continue + else: + # we only record @adjusted_ts if we need + # to record stats of CamFlow dataset. + if CONSOLE_ARGUMENTS.stats: + ts_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in used[uid]: + # an edge must have a jiffies timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without jiffies: {}".format(uid)) + continue + else: + # we only record @jiffies if + # the option is set + if CONSOLE_ARGUMENTS.jiffies: + jiffies = used[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasGeneratedBy" + if "cf:id" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasGeneratedBy[uid]["cf:id"] + if "prov:entity" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasGeneratedBy[uid]: + if 
CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasGeneratedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasInformedBy" + if "cf:id" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasInformedBy[uid]["cf:id"] + if "prov:informant" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + if "prov:informed" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( + 
wasInformedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasInformedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) + continue + else: + edgetype = "wasDerivedFrom" + if "cf:id" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasDerivedFrom[uid]["cf:id"] + if "prov:usedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + 
logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasDerivedFrom[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) + continue + else: + edgetype = "wasAssociatedWith" + if "cf:id" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasAssociatedWith[uid]["cf:id"] + if "prov:agent" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasAssociatedWith[uid]: 
+ if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasAssociatedWith[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + f.close() + output.close() + pb.close() + return total_edges + +def read_single_graph(file_name, threshold): + graph = [] + edge_cnt = 0 + with open(file_name, 'r') as f: + for line in f: + try: + edge = line.strip().split("\t") + new_edge = [edge[0], edge[1]] + attributes = edge[2].strip().split(":") + source_node_type = attributes[0] + destination_node_type = attributes[1] + edge_type = attributes[2] + edge_order = attributes[3] + + new_edge.append(source_node_type) + new_edge.append(destination_node_type) + new_edge.append(edge_type) + new_edge.append(edge_order) + graph.append(new_edge) + edge_cnt += 1 + except: + print("{}".format(line)) + f.close() + graph.sort(key=lambda e: e[5]) + if len(graph) <= threshold: + return graph + else: + return graph[:threshold] + + +def process_graph(name, threshold): + graph = read_single_graph(name, threshold) + result_graph = nx.DiGraph() + cnt = 0 + for num, edge in enumerate(graph): + cnt += 1 + src, dst, src_type, dst_type, edge_type = edge[:5] + if True:# src_type in valid_node_type and dst_type in valid_node_type: + if not result_graph.has_node(src): + result_graph.add_node(src, type=src_type) + if not result_graph.has_node(dst): + result_graph.add_node(dst, type=dst_type) + if not result_graph.has_edge(src, dst): + result_graph.add_edge(src, dst, type=edge_type) + if bidirection: + result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) + return cnt, result_graph + + +node_type_list = [] +edge_type_list = [] +node_type_dict = {} +edge_type_dict = {} + + +def format_graph(g, name): + new_g = nx.DiGraph() + node_map = {} + node_cnt = 0 + for n in g.nodes: + node_map[n] = node_cnt + new_g.add_node(node_cnt, type=g.nodes[n]['type']) + node_cnt += 1 + for e in g.edges: + new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) + for n in new_g.nodes: + node_type = new_g.nodes[n]['type'] + if not node_type in node_type_dict: + node_type_list.append(node_type) + node_type_dict[node_type] = 1 + else: + node_type_dict[node_type] += 1 + for e in new_g.edges: + edge_type = new_g.edges[e]['type'] + if not edge_type in edge_type_dict: + edge_type_list.append(edge_type) + edge_type_dict[edge_type] = 1 + else: + edge_type_dict[edge_type] += 1 + for n in new_g.nodes: + new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) + for e in new_g.edges: + new_g.edges[e]['type'] 
= edge_type_list.index(new_g.edges[e]['type']) + with open('{}.json'.format(name), 'w', encoding='utf-8') as f: + json.dump(nx.node_link_data(new_g), f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist') + args = parser.parse_args() + args.stats = False + args.verbose = False + args.jiffies = False + args.input = '../data/wget/raw/' + args.output = '../data/wget/processed/' + args.final_output = '../data/wget/final/' + args.noencode = False + if not os.path.exists(args.input): + os.mkdir(args.input) + if not os.path.exists(args.output): + os.mkdir(args.output) + if not os.path.exists(args.final_output): + os.mkdir(args.final_output) + CONSOLE_ARGUMENTS = args + + if args.verbose: + logging.basicConfig(filename=args.log, level=logging.DEBUG) + cnt = 0 + for fname in os.listdir(args.input): + cnt += 1 + node_map = dict() + parse_all_nodes(args.input + '/{}'.format(fname), node_map) + total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map, + args.noencode) + if args.stats: + total_nodes = len(node_map) + stats = open(args.stats_file + '/{}.log'.format(cnt), "a+") + stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges)) + + bidirection = False + threshold = 10000000 # infinity + interaction_dict = [] + graph_cnt = 0 + result_graphs = [] + input = args.output + base = args.final_output + + line_cnt = 0 + for i in tqdm(range(cnt)): + single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold) + format_graph(result_graph, '{}{}'.format(base, i)) + line_cnt += single_cnt + + print(line_cnt // 150) + print(len(node_type_list)) + print(node_type_dict) + print(len(edge_type_list)) + print(edge_type_dict) + diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc deleted file mode 100644 index ca2a48894d4871479b958f92c0ae288f5c3e6cad..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/config.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/config.cpython-311.pyc b/utils/__pycache__/config.cpython-311.pyc old mode 100644 new mode 100755 index 0a2ca013c4c0826c94b817a15f5aeaadf3fbf021..55cf6b81b72fe73af01f3176b76fd0081827f1e8 Binary files a/utils/__pycache__/config.cpython-311.pyc and b/utils/__pycache__/config.cpython-311.pyc differ diff --git a/utils/__pycache__/configJupyter.cpython-311.pyc b/utils/__pycache__/configJupyter.cpython-311.pyc deleted file mode 100644 index 470b992f5214101b8dcf54596f455e87c14e0d9b..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/configJupyter.cpython-311.pyc and /dev/null differ diff --git a/utils/__pycache__/dataloader.cpython-310.pyc b/utils/__pycache__/dataloader.cpython-310.pyc deleted file mode 100644 index d87ef800e3349966c54f8aff0158c342c15ce957..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/dataloader.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/dataloader.cpython-311.pyc b/utils/__pycache__/dataloader.cpython-311.pyc deleted file mode 100644 index 435191a8fbab8cf94e7ec9ff76aa01f782a87157..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/dataloader.cpython-311.pyc and /dev/null differ diff --git a/utils/__pycache__/loaddata.cpython-311.pyc b/utils/__pycache__/loaddata.cpython-311.pyc old mode 100644 new mode 100755 index 
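For reference, the per-graph files written by `format_graph` are plain node-link JSON: node ids are remapped to `0..n-1` and node/edge `type` attributes are already integer indices at that point. A minimal sketch with an invented two-node graph (key order in the printed dict may vary by networkx version):

```
import json
import networkx as nx

g = nx.DiGraph()
g.add_node(0, type=0)
g.add_node(1, type=1)
g.add_edge(0, 1, type=2)

print(json.dumps(nx.node_link_data(g)))
# e.g. {"directed": true, "multigraph": false, "graph": {},
#       "nodes": [{"type": 0, "id": 0}, {"type": 1, "id": 1}],
#       "links": [{"type": 2, "source": 0, "target": 1}]}
```

Files in this node-link format are what the dataset loaders elsewhere in this diff read back via `nx.node_link_graph` and `dgl.from_networkx(..., node_attrs=['type'], edge_attrs=['type'])`.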
516f605481c393aef9331a111ab596d8f327e0ee..6dadf00a8f9a10b0fd4c0f82c44e99ec4cd3ffa0 Binary files a/utils/__pycache__/loaddata.cpython-311.pyc and b/utils/__pycache__/loaddata.cpython-311.pyc differ diff --git a/utils/__pycache__/poolers.cpython-310.pyc b/utils/__pycache__/poolers.cpython-310.pyc deleted file mode 100644 index 441c02bbdd690fbfccc5acf786b9ded0f067bd8d..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/poolers.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/poolers.cpython-311.pyc b/utils/__pycache__/poolers.cpython-311.pyc old mode 100644 new mode 100755 index 4464f0cddb9bf9523f2ee1a43193ac41ebdb2f73..d761b00f03cf59c7afe8815f579b18fed2f7dd8f Binary files a/utils/__pycache__/poolers.cpython-311.pyc and b/utils/__pycache__/poolers.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 10658e8f670d580962f3f3982605978f526561a7..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc old mode 100644 new mode 100755 index 16c0ff7b920416d5844cb0abf8995663dd0bf741..f2d041d338d864c03c2baf1f746aaa8bba806202 Binary files a/utils/__pycache__/utils.cpython-311.pyc and b/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/utils/config.py b/utils/config.py old mode 100644 new mode 100755 index 75b33c7e63bcbcab20114e07ab7075922c59417b..ee14920f4b2e36dfcd0d7fddca100572837ade20 --- a/utils/config.py +++ b/utils/config.py @@ -1,23 +1,13 @@ -import argparse - - def build_args(): - parser = argparse.ArgumentParser(description="MAGIC") - parser.add_argument("--device", type=int, default=-1) - parser.add_argument("--lr", type=float, default=0.0001, - help="learning rate") - parser.add_argument("--weight_decay", type=float, default=5e-4, - help="weight decay") - parser.add_argument("--negative_slope", type=float, default=0.2, - help="the negative slope of leaky relu for GAT") - parser.add_argument("--mask_rate", type=float, default=0.5) - parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") - parser.add_argument("--optimizer", type=str, default="adam") - parser.add_argument("--loss_fn", type=str, default='sce') - parser.add_argument("--pooling", type=str, default="mean") - parser.add_argument("--run_id", type=str, default="fedgraphnn_cs_ch") - parser.add_argument("--cf", type=str, default="fedml_config.yaml") - parser.add_argument("--rank", type=int, default=0) - parser.add_argument("--role", type=str, default="server") - args = parser.parse_args() - return args \ No newline at end of file + args = {} + args["dataset"] ="wget" + args["device"]=-1 + args["lr"]=0.001 + args["weight_decay"]=5e-4 + args["negative_slope"]=0.2 + args["mask_rate"]=0.5 + args["alpha_l"]=3 + args["optimizer"]="adam" + args["loss_fn"]='sce' + args["pooling"] = "mean" + return args diff --git a/utils/configJupyter.py b/utils/configJupyter.py deleted file mode 100644 index 269779745a9c835b757ea725d201e11e5ee3e87b..0000000000000000000000000000000000000000 --- a/utils/configJupyter.py +++ /dev/null @@ -1,14 +0,0 @@ - -def build_args(): - - args = {} - args['device'] = -1 - args['lr'] = 0.01 - args['weight_decay'] = 5e-4 - args['negative_slope'] = 0.2 - args['mask_rate'] = 0.5 - args['alpha_l'] = 3 - args['optimizer'] = 'adam' - args['loss_fn'] = "sce" - args['pooling'] = "mean" - return args \ No 
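Since the argparse-based CLI is replaced here by a plain dictionary, callers now read and override hyperparameters directly. A hypothetical usage sketch, assuming the repository layout shown in this diff:

```
from utils.config import build_args

args = build_args()
print(args["dataset"], args["lr"], args["pooling"])  # wget 0.001 mean
args["lr"] = 0.0005  # hyperparameters can now be overridden in plain Python
```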
newline at end of file diff --git a/utils/dataloader.py b/utils/dataloader.py deleted file mode 100644 index 992c6fcacaee228330db91e920c03c6d4f2fc8ad..0000000000000000000000000000000000000000 --- a/utils/dataloader.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import random -import networkx as nx -import dgl -import torch -import pickle as pkl -import json -import logging -import numpy as np - -path_dataset = 'D:/PFE DATASETS/' - -def darpa_split(name): - metadata = load_metadata(name) - n_train = metadata['n_train'] - train_dataset = range(n_train) - train_labels = [0]* n_train - - - return ( - train_dataset, - train_labels, - [], - [], - [], - [] - ) - - -def create_random_split(name, snapshots): - dataset = load_data(name) - # Random 80/10/10 split as suggested - - - all_idxs = list(range(len(dataset))) - random.shuffle(all_idxs) - - train_dataset = dataset['train_index'] - train_labels = [] - for id in train_dataset: - train_labels.append(dataset['labels'][id]) - - val_dataset = dataset['validation_index'] - val_labels = [] - for id in val_dataset: - val_labels.append(dataset['labels'][id]) - - test_dataset = dataset['test_index'] - test_labels = [] - for id in test_dataset: - test_labels.append(dataset['labels'][id]) - - - return ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) - - - -def partition_data_by_sample_size( - client_number, name, snapshots -): - if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) = create_random_split(name, snapshots) - else: - ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) = darpa_split(name) - - num_train_samples = len(train_dataset) - num_val_samples = len(val_dataset) - num_test_samples = len(test_dataset) - - train_idxs = list(range(num_train_samples)) - val_idxs = list(range(num_val_samples)) - test_idxs = list(range(num_test_samples)) - - random.shuffle(train_idxs) - random.shuffle(val_idxs) - random.shuffle(test_idxs) - - partition_dicts = [None] * client_number - - - clients_idxs_train = np.array_split(train_idxs, client_number) - clients_idxs_val = np.array_split(val_idxs, client_number) - clients_idxs_test = np.array_split(test_idxs, client_number) - - labels_of_all_clients = [] - for client in range(client_number): - client_train_idxs = clients_idxs_train[client] - client_val_idxs = clients_idxs_val[client] - client_test_idxs = clients_idxs_test[client] - - train_dataset_client = [ - train_dataset[idx] for idx in client_train_idxs - ] - train_labels_client = [train_labels[idx] for idx in client_train_idxs] - labels_of_all_clients.append(train_labels_client) - - val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] - val_labels_client = [val_labels[idx] for idx in client_val_idxs] - - test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] - test_labels_client = [test_labels[idx] for idx in client_test_idxs] - - - partition_dict = { - "train": train_dataset_client, - "val": val_dataset_client, - "test": test_dataset_client, - } - - partition_dicts[client] = partition_dict - global_data_dict = { - "train": train_dataset, - "val": val_dataset, - "test": test_dataset, - } - - return global_data_dict, partition_dicts - -def load_partition_data( - client_number, - name, - snapshots, - global_test=True, -): - global_data_dict, partition_dicts = partition_data_by_sample_size( - 
client_number, name, snapshots - ) - - data_local_num_dict = dict() - train_data_local_dict = dict() - val_data_local_dict = dict() - test_data_local_dict = dict() - - - - # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. - train_data_global = global_data_dict["train"] - val_data_global = global_data_dict["val"] - test_data_global = global_data_dict["test"] - train_data_num = len(global_data_dict["train"]) - val_data_num = len(global_data_dict["val"]) - test_data_num = len(global_data_dict["test"]) - - for client in range(client_number): - train_dataset_client = partition_dicts[client]["train"] - val_dataset_client = partition_dicts[client]["val"] - test_dataset_client = partition_dicts[client]["test"] - - data_local_num_dict[client] = len(train_dataset_client) - train_data_local_dict[client] = train_dataset_client, - - val_data_local_dict[client] = val_dataset_client - - test_data_local_dict[client] = ( - test_data_global - if global_test - else test_dataset_client - - ) - - logging.info( - "Client idx = {}, local sample number = {}".format( - client, len(train_dataset_client) - ) - ) - - return ( - train_data_num, - val_data_num, - test_data_num, - train_data_global, - val_data_global, - test_data_global, - data_local_num_dict, - train_data_local_dict, - val_data_local_dict, - test_data_local_dict, - ) - - - - - - - -def preload_entity_level_dataset(name): - path = path_dataset + name - if os.path.exists(path + '/metadata.json'): - pass - else: - - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - n_train = len(os.listdir(path + '/train')) - n_test = len(os.listdir(path + '/test')) - - g = pkl.load(open(path + '/train/graph0/graph0.pkl', 'rb')) - - node_feature_dim = len(g.ndata['attr'][0]) - edge_feature_dim = len(g.edata['attr'][0]) - - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': n_train, - 'n_test': n_test - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - - - -def load_metadata(name): - preload_entity_level_dataset(name) - with open( path_dataset + name + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(name, t, n, snapshot, device): - preload_entity_level_dataset(name) - graphs = [] - for i in range(snapshot): - with open(path_dataset + name + '/' + t + '/graph{}/graph{}.pkl'.format(n, str(i)), 'rb') as f: - graphs.append(pkl.load(f).to(device)) - return graphs - - -def get_labels(name): - if (name=="wget" ): - return [1] * 25 + [0] * 125 - elif (name=="streamspot"): - return [0] * 300 + [1] * 100 + [0] * 200 - elif (name == 'SC2'): - return [0] * 125 + [1] * 25 - elif (name == 'Unicorn-Cadets'): - return [0] * 109 + [1] * 3 - elif (name == 'wget-long'): - return [0] * 100 + [1] * 5 - elif (name == 'clearscope-e3'): - return [0] * 44 + [1] * 50 - -def load_data(name): - if name == "wget": - n, n_dim, e_dim = 150, 14, 4 - full_dataset_index = list(range(n)) - train_dataset = list(range(50, 150)) - validation_dataset = list(range(50)) - test_dataset = list(range(50)) - elif name == "streamspot": - n, n_dim, e_dim = 600, 8, 26 - full_dataset_index = list(range(n)) - train_dataset = list(range(300)) - validation_dataset = list(range(300, 350)) + list(range(500,550)) - test_dataset = list(range(300, 400)) + list(range(400,500))+ list(range(500,600)) - elif name == 'SC2': - n_dim = len(pkl.load(open(path_dataset + 'SC2/node.pkl', 
'rb')).keys()) - e_dim = len(pkl.load(open(path_dataset + 'SC2/edge.pkl', 'rb')).keys()) - n, full_dataset_index = 150, list(range(150)) - train_dataset = list(range(100)) - validation_dataset = list(range(100, 150)) - test_dataset = list(range(100, 150)) - elif name in ['Unicorn-Cadets', 'wget-long', 'clearscope-e3']: - n_dim = len(pkl.load(open(path_dataset + '{}/node.pkl'.format(name), 'rb')).keys()) - e_dim = len(pkl.load(open(path_dataset + '{}/edge.pkl'.format(name), 'rb')).keys()) - if name == 'Unicorn-Cadets': - n, train_dataset = 112, list(range(70)) - elif name == 'wget-long': - n, train_dataset = 105, list(range(70)) - else: - n, train_dataset = 94, list(range(30)) - full_dataset_index = list(range(n)) - validation_dataset = list(range(train_dataset[-1], n)) - test_dataset = validation_dataset - return {'dataset': full_dataset_index, - 'train_index': train_dataset, - 'test_index': test_dataset, - 'validation_index': validation_dataset, - 'full_index': full_dataset_index, - 'n_feat': n_dim, - 'e_feat': e_dim, - 'labels': get_labels(name)} - - - -def load_graph(id, name ,device): - graphs = [] - - if (name == "wget"): - path = path_dataset + 'wget/cache/' + 'graph{}'.format(str(id)) - elif (name == "streamspot"): - path = path_dataset + 'streamspot/cache/' + 'graph{}'.format(str(id)) - elif (name == "SC2"): - if (id < 125): path = path_dataset + 'SC2/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'SC2/cache/attack/' + 'graph{}'.format(str(id - 125)) - elif (name == 'Unicorn-Cadets'): - if (id < 109): path = path_dataset + 'Unicorn-Cadets/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'Unicorn-Cadets/cache/attack/' + 'graph{}'.format(str(id - 109)) - elif (name == 'wget-long'): - if (id < 100): path = path_dataset + 'wget-long/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'wget-long/cache/attack/' + 'graph{}'.format(str(id - 100)) - elif (name == 'clearscope-e3'): - if (id < 44): path = path_dataset + 'clearscope-e3/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'clearscope-e3/cache/attack/' + 'graph{}'.format(str(id - 44)) - - for fname in os.listdir(path): - graphs.append(pkl.load(open(path + '/' + fname, 'rb')).to(device)) - - return graphs \ No newline at end of file diff --git a/utils/loaddata.py b/utils/loaddata.py new file mode 100755 index 0000000000000000000000000000000000000000..ca48e0fd60aa69712f48c06f8083b37f679cf797 --- /dev/null +++ b/utils/loaddata.py @@ -0,0 +1,207 @@ +import pickle as pkl +import time +import torch.nn.functional as F +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os + + + +class StreamspotDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(StreamspotDataset, self).__init__(name=name) + if name == 'streamspot': + path = './data/streamspot' + num_graphs = 600 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 300 <= idx <= 399: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def 
__init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + pathattack = '/data/wget/finalattack' + pathbenin = 'data/wget/finalbenin' + num_graphs_benin = 125 + num_graphs_attack = 25 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs_benin)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(0) + + for i in tqdm(range(num_graphs_attack)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(1) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +def load_rawdata(name): + if name == 'streamspot': + path = './data/streamspot' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = StreamspotDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + elif name == 'wget': + path = './data/wget' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = WgetDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + else: + raise NotImplementedError + return raw_data + + +def load_batch_level_dataset(dataset_name): + dataset = load_rawdata(dataset_name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +def transform_graph(g, node_feature_dim, edge_feature_dim): + new_g = g.clone() + new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() + new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() + return new_g + + +def preload_entity_level_dataset(path): + path = './data/' + path + if os.path.exists(path + '/metadata.json'): + pass + else: + print('transforming') + train_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] + print('transforming') + test_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + node_feature_dim = 0 + for g in train_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + for g in test_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + node_feature_dim += 1 + 
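The `transform_graph` helper above turns the integer `type` labels into one-hot `attr` features, with `num_classes` set to the largest observed type id plus one. A pure-PyTorch sketch of that step (dgl omitted; the tensor is invented and stands in for `g.ndata["type"]` of a three-node graph):

```
import torch
import torch.nn.functional as F

node_types = torch.tensor([0, 2, 1])   # integer node-type ids
node_feature_dim = 3                   # max type id + 1, as computed above

attr = F.one_hot(node_types.view(-1), num_classes=node_feature_dim).float()
print(attr)
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])
```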
edge_feature_dim = 0 + for g in train_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + for g in test_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + edge_feature_dim += 1 + result_test_gs = [] + for g in test_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_test_gs.append(g) + result_train_gs = [] + for g in train_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_train_gs.append(g) + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': len(result_train_gs), + 'n_test': len(result_test_gs) + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + for i, g in enumerate(result_train_gs): + with open(path + '/train{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + for i, g in enumerate(result_test_gs): + with open(path + '/test{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + + +def load_metadata(path): + preload_entity_level_dataset(path) + with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(path, t, n): + preload_entity_level_dataset(path) + with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: + data = pkl.load(f) + return data diff --git a/utils/poolers.py b/utils/poolers.py old mode 100644 new mode 100755 diff --git a/utils/streamspot_parser.py b/utils/streamspot_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..07687828ea5ecf56f9859955e599afb584826f1f --- /dev/null +++ b/utils/streamspot_parser.py @@ -0,0 +1,55 @@ +import networkx as nx +from tqdm import tqdm +import json +raw_path = '../data/streamspot/' + +NUM_GRAPHS = 600 +node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] +edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] +node_type_set = set(node_type_dict) +edge_type_set = set(edge_type_dict) +print(raw_path) +count_graph = 0 +with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: + print("Reading") + lines = f.readlines() + print("starting") + g = nx.DiGraph() + node_map = {} + count_node = 0 + for line in tqdm(lines): + src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') + graph_id = int(graph_id) + if src_type not in node_type_set or dst_type not in node_type_set: + continue + if etype not in edge_type_set: + continue + if graph_id != count_graph: + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) + assert graph_id == count_graph + g = nx.DiGraph() + count_node = 0 + if src not in node_map: + node_map[src] = count_node + g.add_node(count_node, type=src_type) + count_node += 1 + if dst not in node_map: + node_map[dst] = count_node + g.add_node(count_node, type=dst_type) + count_node += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=etype) + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + 
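For reference, the parser above expects `all.tsv` lines with six tab-separated columns; a made-up example, split exactly as the loop does:

```
# Illustrative StreamSpot all.tsv line: values are invented, columns are
# src, src_type, dst, dst_type, edge_type, graph_id.
line = "12\ta\t57\tc\tq\t0\n"
src, src_type, dst, dst_type, etype, graph_id = line.strip("\n").split("\t")
print(src, src_type, dst, dst_type, etype, int(graph_id))  # 12 a 57 c q 0
```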
json.dump(nx.node_link_data(g), f1) diff --git a/utils/trace_parser.py b/utils/trace_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7 --- /dev/null +++ b/utils/trace_parser.py @@ -0,0 +1,288 @@ +import argparse +import json +import os +import random +import re + +from tqdm import tqdm +import networkx as nx +import pickle as pkl + + +node_type_dict = {} +edge_type_dict = {} +node_type_cnt = 0 +edge_type_cnt = 0 + +metadata = { + 'trace':{ + 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], + 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] + }, + 'theia':{ + 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], + 'test': ['ta1-theia-e3-official-6r.json.8'] + }, + 'cadets':{ + 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], + 'test': ['ta1-cadets-e3-official-2.json'] + } +} + + +pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') +pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_type = re.compile(r'type\":\"(.*?)\"') +pattern_time = re.compile(r'timestampNanos\":(.*?),') +pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') +pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') +pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') + +adversarial = 'None' + + +def read_single_graph(dataset, malicious, path, test=False): + global node_type_cnt, edge_type_cnt + g = nx.DiGraph() + print('converting {} ...'.format(path)) + path = '../data/{}/'.format(dataset) + path + '.txt' + f = open(path, 'r') + lines = [] + for l in f.readlines(): + split_line = l.split('\t') + src, src_type, dst, dst_type, edge_type, ts = split_line + ts = int(ts) + if not test: + if src in malicious or dst in malicious: + if src in malicious and src_type != 'MemoryObject': + continue + if dst in malicious and dst_type != 'MemoryObject': + continue + + if src_type not in node_type_dict: + node_type_dict[src_type] = node_type_cnt + node_type_cnt += 1 + if dst_type not in node_type_dict: + node_type_dict[dst_type] = node_type_cnt + node_type_cnt += 1 + if edge_type not in edge_type_dict: + edge_type_dict[edge_type] = edge_type_cnt + edge_type_cnt += 1 + if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: + lines.append([dst, src, dst_type, src_type, edge_type, ts]) + else: + lines.append([src, dst, src_type, dst_type, edge_type, ts]) + lines.sort(key=lambda l: l[5]) + + node_map = {} + node_type_map = {} + node_cnt = 0 + node_list = [] + for l in lines: + src, dst, src_type, dst_type, edge_type = l[:5] + src_type_id = node_type_dict[src_type] + dst_type_id = node_type_dict[dst_type] + edge_type_id = edge_type_dict[edge_type] + if src not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if src in malicious: + src_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if src not in malicious: + i 
= random.randint(1, 20) + if i == 1: + src_type_id = node_type_dict['NetFlowObject'] + node_map[src] = node_cnt + g.add_node(node_cnt, type=src_type_id) + node_list.append(src) + node_type_map[src] = src_type + node_cnt += 1 + if dst not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if dst in malicious: + dst_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if dst not in malicious: + i = random.randint(1, 20) + if i == 1: + dst_type_id = node_type_dict['NetFlowObject'] + node_map[dst] = node_cnt + g.add_node(node_cnt, type=dst_type_id) + node_type_map[dst] = dst_type + node_list.append(dst) + node_cnt += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=edge_type_id) + if (adversarial == 'MSE' or adversarial == 'MCE') and test: + for i, node in enumerate(node_list): + if node in malicious: + while True: + another_node = random.choice(node_list) + if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) + print(node, another_node) + break + if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) + print(node, another_node) + break + if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) + print(node, another_node) + break + if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': + break + return node_map, g + + +def preprocess_dataset(dataset): + id_nodetype_map = {} + id_nodename_map = {} + for file in os.listdir('../data/{}/'.format(dataset)): + if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: + print('reading {} ...'.format(file)) + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue + if len(pattern_uuid.findall(line)) == 0: print(line) + uuid = pattern_uuid.findall(line)[0] + subject_type = pattern_type.findall(line) + + if len(subject_type) < 1: + if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: + subject_type = 'MemoryObject' + if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: + subject_type = 'NetFlowObject' + if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: + subject_type = 'UnnamedPipeObject' + else: + subject_type = subject_type[0] + + if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: + continue + id_nodetype_map[uuid] = subject_type + if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_file_name.findall(line)[0] + elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_process_name.findall(line)[0] + elif subject_type == 'NetFlowObject' and 
len(pattern_netflow_object_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] + for key in metadata[dataset]: + for file in metadata[dataset][key]: + if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): + continue + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') + print('processing {} ...'.format(file)) + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line: + edgeType = pattern_type.findall(line)[0] + timestamp = pattern_time.findall(line)[0] + srcId = pattern_src.findall(line) + + if len(srcId) == 0: continue + srcId = srcId[0] + if not srcId in id_nodetype_map: + continue + srcType = id_nodetype_map[srcId] + dstId1 = pattern_dst1.findall(line) + if len(dstId1) > 0 and dstId1[0] != 'null': + dstId1 = dstId1[0] + if not dstId1 in id_nodetype_map: + continue + dstType1 = id_nodetype_map[dstId1] + this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( + dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge1) + + dstId2 = pattern_dst2.findall(line) + if len(dstId2) > 0 and dstId2[0] != 'null': + dstId2 = dstId2[0] + if not dstId2 in id_nodetype_map.keys(): + continue + dstType2 = id_nodetype_map[dstId2] + this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( + dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge2) + fw.close() + f.close() + if len(id_nodename_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') + json.dump(id_nodename_map, fw) + if len(id_nodetype_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') + json.dump(id_nodetype_map, fw) + + +def read_graphs(dataset): + malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) + f = open(malicious_entities, 'r') + malicious_entities = set() + for l in f.readlines(): + malicious_entities.add(l.lstrip().rstrip()) + + preprocess_dataset(dataset) + train_gs = [] + for file in metadata[dataset]['train']: + _, train_g = read_single_graph(dataset, malicious_entities, file, False) + train_gs.append(train_g) + test_gs = [] + test_node_map = {} + count_node = 0 + for file in metadata[dataset]['test']: + node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) + assert len(node_map) == test_g.number_of_nodes() + test_gs.append(test_g) + for key in node_map: + if key not in test_node_map: + test_node_map[key] = node_map[key] + count_node + count_node += test_g.number_of_nodes() + + if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): + with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodename_map = json.load(f) + with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodetype_map = json.load(f) + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': + final_malicious_entities.append(test_node_map[e]) + if e in id_nodename_map: + malicious_names.append(id_nodename_map[e]) + f.write('{}\t{}\n'.format(e, id_nodename_map[e])) + else: + malicious_names.append(e) + 
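The intermediate `<file>.txt` edge lists written by `preprocess_dataset` hold one tab-separated edge per line; the sketch below shows an invented example and the direction flip that `read_single_graph` applies to read-like events:

```
# Illustrative line from the intermediate '<file>.txt' edge list; the UUIDs
# and timestamp are invented.
line = "uuid-proc\tSUBJECT_PROCESS\tuuid-file\tFILE_OBJECT_FILE\tEVENT_READ\t1523018600000000000\n"
src, src_type, dst, dst_type, edge_type, ts = line.strip("\n").split("\t")
ts = int(ts)

# read_single_graph() reverses READ/RECV/LOAD events so information always
# flows from source to destination in the resulting graph.
if "READ" in edge_type or "RECV" in edge_type or "LOAD" in edge_type:
    src, dst, src_type, dst_type = dst, src, dst_type, src_type
print(src, "->", dst, edge_type)  # uuid-file -> uuid-proc EVENT_READ
```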
f.write('{}\t{}\n'.format(e, e)) + else: + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map: + final_malicious_entities.append(test_node_map[e]) + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + + pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CDM Parser') + parser.add_argument("--dataset", type=str, default="trace") + args = parser.parse_args() + if args.dataset not in ['trace', 'theia', 'cadets']: + raise NotImplementedError + read_graphs(args.dataset) + diff --git a/utils/utils.py b/utils/utils.py old mode 100644 new mode 100755 diff --git a/utils/wget_parser.py b/utils/wget_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055 --- /dev/null +++ b/utils/wget_parser.py @@ -0,0 +1,820 @@ +import argparse +import json +import os + +import xxhash +import tqdm +import logging +import networkx as nx +from tqdm import tqdm +import time +import datetime + + +valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] +CONSOLE_ARGUMENTS = None + + +def hashgen(l): + """Generate a single hash value from a list. @l is a list of + string values, which can be properties of a node/edge. This + function returns a single hashed integer value.""" + hasher = xxhash.xxh64() + for e in l: + hasher.update(e) + return hasher.intdigest() + + +def parse_nodes(json_string, node_map): + """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). + Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, + which is assigned by CamFlow to uniquely identify a node object, to a hashed + value (in str) which represents the 'type' of the node. """ + json_object = None + try: + # use "ignore" if non-decodeable exists in the @json_string + json_object = json.loads(json_string) + except Exception as e: + print("Exception ({}) occurred when parsing a node in JSON:".format(e)) + print(json_string) + exit(1) + if "activity" in json_object: + activity = json_object["activity"] + for uid in activity: + if not uid in node_map: # only parse unseen nodes + if "prov:type" not in activity[uid]: + # a node must have a type. + # record this issue if logging is turned on + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = activity[uid]["prov:type"] + + if "entity" in json_object: + entity = json_object["entity"] + for uid in entity: + if not uid in node_map: + if "prov:type" not in entity[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = entity[uid]["prov:type"] + + +def parse_all_nodes(filename, node_map): + """Parse all nodes in CamFlow data. @filename is the file path of + the CamFlow data to parse. @node_map contains the mappings of all + CamFlow nodes to their hashed attributes. 
""" + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(filename, 'r') as f: + # each line in CamFlow data could contain multiple + # provenance nodes, we call @parse_nodes routine. + for line in f: + pb.update() # for progress tracking + parse_nodes(line, node_map) + f.close() + pb.close() + + +def parse_all_edges(inputfile, outputfile, node_map, noencode): + """Parse all edges (including their timestamp) from CamFlow data file @inputfile + to an @outputfile. Before this function is called, parse_all_nodes should be called + to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, + we do not hash the nodes' original UUIDs generated by CamFlow to integers. This + function returns the total number of valid edge parsed from CamFlow dataset. + + The output edgelist has the following format for each line, if -s is not set: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> + If -s is set, each line would look like: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" + total_edges = 0 + smallest_timestamp = None + # scan through the entire file to find the smallest timestamp from all the edges. + # this step is only needed if we need to add some statistical information. + if CONSOLE_ARGUMENTS.stats: + description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + continue + if "cf:date" not in used[uid]: + continue + if "prov:entity" not in used[uid]: + continue + if "prov:activity" not in used[uid]: + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + continue + if "cf:date" not in wasGeneratedBy[uid]: + continue + if "prov:entity" not in wasGeneratedBy[uid]: + continue + if "prov:activity" not in wasGeneratedBy[uid]: + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + continue + if "cf:date" not in wasInformedBy[uid]: + continue + if "prov:informant" not in wasInformedBy[uid]: + continue + if "prov:informed" not in wasInformedBy[uid]: + 
continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + continue + if "cf:date" not in wasDerivedFrom[uid]: + continue + if "prov:usedEntity" not in wasDerivedFrom[uid]: + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + continue + if "cf:date" not in wasAssociatedWith[uid]: + continue + if "prov:agent" not in wasAssociatedWith[uid]: + continue + if "prov:activity" not in wasAssociatedWith[uid]: + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + f.close() + pb.close() + + # we will go through the CamFlow data (again) and output edgelist to a file + output = open(outputfile, "w+") + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + # an edge must have a type; if not, + # we will have to skip the edge. Log + # this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without type: {}".format(uid)) + continue + else: + edgetype = "used" + # cf:id is used as logical timestamp to order edges + if "cf:id" not in used[uid]: + # an edge must have a logical timestamp; + # if not we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = used[uid]["cf:id"] + if "prov:entity" not in used[uid]: + # an edge's source node must exist; + # if not, we will have to skip the + # edge. Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) + continue + if "prov:activity" not in used[uid]: + # an edge's destination node must exist; + # if not, we will have to skip the edge. 
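+                        # (the same validation is repeated for each relation type parsed below)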
+ # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], + uid)) + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + # both source and destination node must + # exist in @node_map; if not, we will + # have to skip the edge. Log this issue + # if verbose is set. + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in used[uid]: + # an edge must have a timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without timestamp: {}".format(uid)) + continue + else: + # we only record @adjusted_ts if we need + # to record stats of CamFlow dataset. + if CONSOLE_ARGUMENTS.stats: + ts_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in used[uid]: + # an edge must have a jiffies timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without jiffies: {}".format(uid)) + continue + else: + # we only record @jiffies if + # the option is set + if CONSOLE_ARGUMENTS.jiffies: + jiffies = used[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasGeneratedBy" + if "cf:id" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasGeneratedBy[uid]["cf:id"] + if "prov:entity" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasGeneratedBy[uid]: + if 
CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasGeneratedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasInformedBy" + if "cf:id" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasInformedBy[uid]["cf:id"] + if "prov:informant" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + if "prov:informed" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( + 
wasInformedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasInformedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) + continue + else: + edgetype = "wasDerivedFrom" + if "cf:id" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasDerivedFrom[uid]["cf:id"] + if "prov:usedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + 
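+                            # wasDerivedFrom records without a cf:date are dropped even when stats are not being collected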
logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasDerivedFrom[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) + continue + else: + edgetype = "wasAssociatedWith" + if "cf:id" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasAssociatedWith[uid]["cf:id"] + if "prov:agent" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasAssociatedWith[uid]: 
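+                        # cf:jiffies must be present even though it is only written out when the jiffies option is enabled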
+ if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasAssociatedWith[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + f.close() + output.close() + pb.close() + return total_edges + +def read_single_graph(file_name, threshold): + graph = [] + edge_cnt = 0 + with open(file_name, 'r') as f: + for line in f: + try: + edge = line.strip().split("\t") + new_edge = [edge[0], edge[1]] + attributes = edge[2].strip().split(":") + source_node_type = attributes[0] + destination_node_type = attributes[1] + edge_type = attributes[2] + edge_order = attributes[3] + + new_edge.append(source_node_type) + new_edge.append(destination_node_type) + new_edge.append(edge_type) + new_edge.append(edge_order) + graph.append(new_edge) + edge_cnt += 1 + except: + print("{}".format(line)) + f.close() + graph.sort(key=lambda e: e[5]) + if len(graph) <= threshold: + return graph + else: + return graph[:threshold] + + +def process_graph(name, threshold): + graph = read_single_graph(name, threshold) + result_graph = nx.DiGraph() + cnt = 0 + for num, edge in enumerate(graph): + cnt += 1 + src, dst, src_type, dst_type, edge_type = edge[:5] + if True:# src_type in valid_node_type and dst_type in valid_node_type: + if not result_graph.has_node(src): + result_graph.add_node(src, type=src_type) + if not result_graph.has_node(dst): + result_graph.add_node(dst, type=dst_type) + if not result_graph.has_edge(src, dst): + result_graph.add_edge(src, dst, type=edge_type) + if bidirection: + result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) + return cnt, result_graph + + +node_type_list = [] +edge_type_list = [] +node_type_dict = {} +edge_type_dict = {} + + +def format_graph(g, name): + new_g = nx.DiGraph() + node_map = {} + node_cnt = 0 + for n in g.nodes: + node_map[n] = node_cnt + new_g.add_node(node_cnt, type=g.nodes[n]['type']) + node_cnt += 1 + for e in g.edges: + new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) + for n in new_g.nodes: + node_type = new_g.nodes[n]['type'] + if not node_type in node_type_dict: + node_type_list.append(node_type) + node_type_dict[node_type] = 1 + else: + node_type_dict[node_type] += 1 + for e in new_g.edges: + edge_type = new_g.edges[e]['type'] + if not edge_type in edge_type_dict: + edge_type_list.append(edge_type) + edge_type_dict[edge_type] = 1 + else: + edge_type_dict[edge_type] += 1 + for n in new_g.nodes: + new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) + for e in new_g.edges: + new_g.edges[e]['type'] 
= edge_type_list.index(new_g.edges[e]['type'])
+    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
+        json.dump(nx.node_link_data(new_g), f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
+    args = parser.parse_args()
+    args.stats = False
+    args.verbose = False
+    args.jiffies = False
+    args.input = '../data/wget/raw/'
+    args.output = '../data/wget/processed/'
+    args.final_output = '../data/wget/final/'
+    args.noencode = False
+    if not os.path.exists(args.input):
+        os.mkdir(args.input)
+    if not os.path.exists(args.output):
+        os.mkdir(args.output)
+    if not os.path.exists(args.final_output):
+        os.mkdir(args.final_output)
+    CONSOLE_ARGUMENTS = args
+
+    if args.verbose:
+        logging.basicConfig(filename=args.log, level=logging.DEBUG)
+    cnt = 0
+    for fname in os.listdir(args.input):
+        cnt += 1
+        node_map = dict()
+        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
+        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
+                                      args.noencode)
+        if args.stats:
+            total_nodes = len(node_map)
+            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
+            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
+
+    bidirection = False
+    threshold = 10000000  # infinity
+    interaction_dict = []
+    graph_cnt = 0
+    result_graphs = []
+    input = args.output
+    base = args.final_output
+
+    line_cnt = 0
+    for i in tqdm(range(cnt)):
+        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
+        format_graph(result_graph, '{}{}'.format(base, i))
+        line_cnt += single_cnt
+
+    print(line_cnt // 150)
+    print(len(node_type_list))
+    print(node_type_dict)
+    print(len(edge_type_list))
+    print(edge_type_dict)
+
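+# Usage sketch, based on the defaults hard-coded above: place raw CamFlow JSON logs
+# under ../data/wget/raw/ and run this script (from the utils/ directory, given the
+# relative ../data paths). Each input file is first turned into an edgelist
+# ../data/wget/processed/<n>.log whose lines (without stats/jiffies) look like
+#   <src_id>\t<dst_id>\t<src_type>:<dst_type>:<edge_type>:<logical_timestamp>
+# and is then re-indexed into a node-link JSON graph under ../data/wget/final/.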