diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index 5b20bef384e8ddfdd2fe2d024b1c99016f1b056f..e4f79bf056df8e85813a38bb76a7f9e6539e186f
--- a/README.md
+++ b/README.md
@@ -1,93 +1,36 @@
-# GraphFL
+The command are used in an environnement that consist of ubuntu 22.04 with miniconda installed
 
+Original project: https://github.com/FDUDSDE/MAGIC
 
-
-## Getting started
-
-To make it easy for you to get started with GitLab, here's a list of recommended next steps.
-
-Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
-
-## Add your files
-
-- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
-- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
+First create the conda environnement for fedml with MPI support 
 
 ```
-cd existing_repo
-git remote add origin https://gitlab.liris.cnrs.fr/gladis/graphfl.git
-git branch -M main
-git push -uf origin main
+conda create --name fedml-pip python=3.8
+conda activate fedml-pip
+conda install --name fedml-pip pip
+conda install -c conda-forge mpi4py openmpi
+pip install "fedml[MPI]" 
 ```
 
-## Integrate with your tools
-
-- [ ] [Set up project integrations](https://gitlab.liris.cnrs.fr/gladis/graphfl/-/settings/integrations)
-
-## Collaborate with your team
-
-- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
-- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
-- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
-- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
-- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
-
-## Test and Deploy
-
-Use the built-in continuous integration in GitLab.
-
-- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html)
-- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
-- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
-- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
-- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
-
-***
-
-# Editing this README
-
-When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template.
+Clone the MAGIC FedML project onto your current folder 
 
-## Suggestions for a good README
-
-Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
-
-## Name
-Choose a self-explaining name for your project.
-
-## Description
-Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
-
-## Badges
-On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
-
-## Visuals
-Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
-
-## Installation
-Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
-
-## Usage
-Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
+```
+git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML
+```
 
-## Support
-Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
+Install the necessary packages for Magic to run
 
-## Roadmap
-If you have ideas for releases in the future, it is a good idea to list them in the README.
+```
+conda install -c conda-forge aiohttp=3.9.1 aiosignal=1.3.1 anyio=4.2.0 attrdict=2.0.1 attrs=23.2.0 blis=0.7.11 boto3=1.34.12 botocore=1.34.12 brotli=1.1.0 catalogue=2.0.10 certifi=2023.11.17 chardet=5.2.0 charset-normalizer=3.3.2 click=8.1.7 cloudpathlib=0.16.0 confection=0.1.4 contourpy=1.2.0 cycler=0.12.1 cymem=2.0.8 dgl=1.1.3 dill=0.3.7 fastapi=0.92.0 fedml=0.8.13.post2 filelock=3.13.1 fonttools=4.47.0 frozenlist=1.4.1 fsspec=2023.12.2 gensim=4.3.2 gevent=23.9.1 geventhttpclient=2.0.9 gitdb=4.0.11 GitPython=3.1.40 GPUtil=1.4.0 graphviz=0.8.4 greenlet=3.0.3 h11=0.14.0 h5py=3.10.0 httpcore=1.0.2 httpx=0.26.0 idna=3.6 Jinja2=3.1.2 jmespath=1.0.1 joblib=1.3.2 kiwisolver=1.4.5 langcodes=3.3.0 MarkupSafe=2.1.3 matplotlib=3.8.2 mpi4py=3.1.3 mpmath=1.3.0 multidict=6.0.4 multiprocess=0.70.15 murmurhash=1.0.10 networkx=2.8.8 ntplib=0.4.0 numpy=1.26.3 nvidia-cublas-cu12=12.1.3.1 nvidia-cuda-cupti-cu12=12.1.105 nvidia-cuda-nvrtc-cu12=12.1.105 nvidia-cuda-runtime-cu12=12.1.105 nvidia-cudnn-cu12=8.9.2.26 nvidia-cufft-cu12=11.0.2.54 nvidia-curand-cu12=10.3.2.106 nvidia-cusolver-cu12=11.4.5.107 nvidia-cusparse-cu12=12.1.0.106 nvidia-nccl-cu12=2.18.1 nvidia-nvjitlink-cu12=12.3.101 nvidia-nvtx-cu12=12.1.105 onnx=1.15.0 packaging=23.2 paho-mqtt=1.6.1 pandas=2.1.4 pathtools=0.1.2 pillow=10.2.0 preshed=3.0.9 prettytable=3.9.0 promise=2.3 protobuf=3.20.3 psutil=5.9.7 py-machineid=0.4.6 pydantic=1.10.13 pyparsing=3.1.1 python-dateutil=2.8.2 python-rapidjson=1.14 pytz=2023.3.post1 PyYAML=6.0.1 redis=5.0.1 requests=2.31.0 s3transfer=0.10.0 scikit-learn=1.3.2 scipy=1.11.4 sentry-sdk=1.39.1 setproctitle=1.3.3 shortuuid=1.0.11 six=1.16.0 smart-open=6.3.0 smmap=5.0.1 sniffio=1.3.0 spacy=3.7.2 spacy-legacy=3.0.12 spacy-loggers=1.0.5 SQLAlchemy=2.0.25 srsly=2.4.8 starlette=0.25.0 sympy=1.12 thinc=8.2.2 threadpoolctl=3.2.0 torch=2.1.2 torch-cluster=1.6.3 torch-scatter=2.1.2 torch-sparse=0.6.18 torch-spline-conv=1.2.2 torch_geometric=2.4.0 torchvision=0.16.2 tqdm=4.66.1 triton=2.1.0 tritonclient=2.41.0 typer=0.9.0 typing_extensions=4.9.0 tzdata=2023.4 tzlocal=5.2 urllib3=2.0.7 uvicorn=0.25.0 wandb=0.13.2 wasabi=1.1.2 wcwidth=0.2.12 weasel=0.3.4 websocket-client=1.7.0 wget=3.2 yarl=1.9.4 zope.event=5.0 zope.interface=6.1
+```
 
-## Contributing
-State if you are open to contributions and what your requirements are for accepting them.
+Finally run the federated algorithm using the mpi command you can change the federated algorithm in `fedml_config.yaml` 
 
-For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
+```
+hostname > mpi_host_file
+mpirun -np 4  -hostfile mpi_host_file --oversubscribe python main.py --cf fedml_config.yaml
+```
 
-You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
 
-## Authors and acknowledgment
-Show your appreciation to those who have contributed to the project.
 
-## License
-For open source projects, say how it is licensed.
 
-## Project status
-If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
diff --git a/a.yaml b/a.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..9f8ce6e7fa889de842853b12c3a723d56ab60899
--- /dev/null
+++ b/a.yaml
@@ -0,0 +1,83 @@
+common_args:
+  training_type: "cross_silo"
+  scenario: "horizontal"
+  using_mlops: false
+  config_version: release
+  name: "exp"
+  project: "runs/train"
+  exist_ok: false
+  random_seed: 0
+  
+  common_args:
+  training_type: "simulation"
+  random_seed: 0
+comm_args:
+  backend: "MPI"
+  is_mobile: 0
+
+
+data_args:
+  dataset: "clintox"
+  data_cache_dir: ~/fedgraphnn_data/
+  part_file:  ~/fedgraphnn_data/partition
+  partition_method: "hetero"
+  partition_alpha: 0.5
+
+model_args:
+  model: "graphsage"
+  hidden_size: 32
+  node_embedding_dim: 32
+  graph_embedding_dim: 64
+  readout_hidden_dim: 64
+  alpha: 0.2
+  num_heads: 2
+  dropout: 0.3
+  normalize_features: False
+  normalize_adjacency: False
+  sparse_adjacency: False
+  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
+  global_model_file_path: "./model_file_cache/global_model.pt"
+
+environment_args:
+  bootstrap: config/bootstrap.sh
+
+train_args:
+  federated_optimizer: "FedAvg"
+  client_id_list: 
+  client_num_in_total: 3
+  client_num_per_round: 3
+  comm_round: 100
+  epochs: 5
+  batch_size: 64
+  client_optimizer: sgd
+  learning_rate: 0.03
+  weight_decay: 0.001
+  metric: "prc-auc"
+  server_optimizer: sgd
+  lr: 0.001
+  server_lr: 0.001
+  wd: 0.001
+  ci: 0
+  server_momentum: 0.9
+
+validation_args:
+  frequency_of_the_test: 1
+
+device_args:
+  worker_num: 3
+  using_gpu: false
+  gpu_mapping_file: config/gpu_mapping.yaml
+  gpu_mapping_key: mapping_fedgraphnn_sp
+
+comm_args:
+  backend: "MQTT_S3"
+  mqtt_config_path: config/mqtt_config.yaml
+  s3_config_path: config/s3_config.yaml
+
+
+tracking_args:
+   # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
+  enable_wandb: false
+  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
+  wandb_project: fedml
+  wandb_name: fedml_torch_moleculenet
diff --git a/a.yml b/a.yml
new file mode 100755
index 0000000000000000000000000000000000000000..adfbd8733fcee48a46e87a24cc48cbb1a0241f87
--- /dev/null
+++ b/a.yml
@@ -0,0 +1,68 @@
+common_args:
+  training_type: "simulation"
+  random_seed: 0
+
+data_args:
+  dataset: "clintox"
+  data_cache_dir: ~/fedgraphnn_data/
+  part_file:  ~/fedgraphnn_data/partition
+  partition_method: "hetero"
+  partition_alpha: 0.5
+
+model_args:
+  model: "graphsage"
+  hidden_size: 32
+  node_embedding_dim: 32
+  graph_embedding_dim: 64
+  readout_hidden_dim: 64
+  alpha: 0.2
+  num_heads: 2
+  dropout: 0.3
+  normalize_features: False
+  normalize_adjacency: False
+  sparse_adjacency: False
+  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
+  global_model_file_path: "./model_file_cache/global_model.pt"
+
+environment_args:
+  bootstrap: config/bootstrap.sh
+
+train_args:
+  federated_optimizer: "FedAvg"
+  client_id_list: 
+  client_num_in_total: 3
+  client_num_per_round: 3
+  comm_round: 100
+  epochs: 5
+  batch_size: 64
+  client_optimizer: sgd
+  learning_rate: 0.03
+  weight_decay: 0.001
+  metric: "prc-auc"
+  server_optimizer: sgd
+  lr: 0.001
+  server_lr: 0.001
+  wd: 0.001
+  ci: 0
+  server_momentum: 0.9
+
+validation_args:
+  frequency_of_the_test: 1
+
+device_args:
+  worker_num: 2
+  using_gpu: false
+  gpu_mapping_file: config/gpu_mapping.yaml
+  gpu_mapping_key: mapping_fedgraphnn_sp
+
+comm_args:
+  backend: "MPI"
+  is_mobile: 0
+
+
+tracking_args:
+   # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
+  enable_wandb: false
+  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
+  wandb_project: fedml
+  wandb_name: fedml_torch_moleculenet
diff --git a/data/__pycache__/data_loader.cpython-311.pyc b/data/__pycache__/data_loader.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..b17492ecf9ba2a9c984303dae0e410e17b03af83
Binary files /dev/null and b/data/__pycache__/data_loader.cpython-311.pyc differ
diff --git a/data/cadets/graphs.zip b/data/cadets/graphs.zip
new file mode 100755
index 0000000000000000000000000000000000000000..7b014b2ccfa0594ff4dc0e742a2d8453aafeb26d
Binary files /dev/null and b/data/cadets/graphs.zip differ
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100755
index 0000000000000000000000000000000000000000..f52db36d726a6d6bd742e77841eb3df28b2de34b
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,332 @@
+import logging
+import pickle as pkl
+import random
+
+import torch.utils.data as data
+
+from fedml.core import partition_class_samples_with_dirichlet_distribution
+import dgl
+import networkx as nx
+import json
+from tqdm import tqdm
+import os
+import numpy as np
+from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata
+
+
+class WgetDataset(dgl.data.DGLDataset):
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(WgetDataset, self).__init__(name=name)
+        if name == 'wget':
+            pathattack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack'
+            pathbenin = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin'
+            num_graphs_benin = 125
+            num_graphs_attack = 25
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs_benin)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                self.labels.append(0)
+                
+            for i in tqdm(range(num_graphs_attack)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                self.labels.append(1)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+def darpa_split(name):
+    device = "cpu"
+    path = './data/' + name + '/'
+    metadata = load_metadata(name)
+    n_train = metadata['n_train']
+    train_dataset = []
+    train_labels = []
+    for i in range(n_train):
+        g = load_entity_level_dataset(name, 'train', i).to(device)
+        train_dataset.append(g)
+        train_labels.append(0)
+    
+    return (
+        train_dataset,
+        train_labels,
+        [],
+        [],
+        [],
+        []
+    )
+       
+
+def create_random_split(name):
+    dataset = load_rawdata(name)
+    # Random 80/10/10 split as suggested 
+    train_range = (0, int(0.8 * len(dataset)))
+    val_range = (
+        int(0.8 * len(dataset)),
+        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
+    )
+    test_range = (
+        int(0.8 * len(dataset)) + int(0.1 * len(dataset)),
+        len(dataset),
+    )
+
+    all_idxs = list(range(len(dataset)))
+    random.shuffle(all_idxs)
+
+    train_dataset = [
+        dataset[all_idxs[i]] for i in range(train_range[0], train_range[1])
+    ]
+   
+    train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])]
+
+    val_dataset = [
+        dataset[all_idxs[i]] for i in range(val_range[0], val_range[1])
+    ]
+    val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])]
+
+    test_dataset  = [
+        dataset[all_idxs[i]] for i in range(test_range[0], test_range[1])
+    ]
+    test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])]
+
+    return (
+        train_dataset,
+        train_labels,
+        val_dataset,
+        val_labels,
+        test_dataset,
+        test_labels,
+    )
+
+
+
+def partition_data_by_sample_size(
+    args, client_number, name, uniform=True, compact=True
+):
+    if (name == 'wget' or name == 'streamspot'):
+        (
+            train_dataset,
+            train_labels,
+            val_dataset,
+            val_labels,
+            test_dataset,
+            test_labels,
+        ) = create_random_split(name)
+    else:
+        (
+            train_dataset,
+            train_labels,
+            val_dataset,
+            val_labels,
+            test_dataset,
+            test_labels,
+        ) = darpa_split(name)
+        
+    num_train_samples = len(train_dataset)
+    num_val_samples = len(val_dataset)
+    num_test_samples = len(test_dataset)
+
+    train_idxs = list(range(num_train_samples))
+    val_idxs = list(range(num_val_samples))
+    test_idxs = list(range(num_test_samples))
+
+    random.shuffle(train_idxs)
+    random.shuffle(val_idxs)
+    random.shuffle(test_idxs)
+
+    partition_dicts = [None] * client_number
+
+    if uniform:
+        clients_idxs_train = np.array_split(train_idxs, client_number)
+        clients_idxs_val = np.array_split(val_idxs, client_number)
+        clients_idxs_test = np.array_split(test_idxs, client_number)
+    else:
+        clients_idxs_train = create_non_uniform_split(
+            args, train_idxs, client_number, True
+        )
+        clients_idxs_val = create_non_uniform_split(
+            args, val_idxs, client_number, False
+        )
+        clients_idxs_test = create_non_uniform_split(
+            args, test_idxs, client_number, False
+        )
+
+    labels_of_all_clients = []
+    for client in range(client_number):
+        client_train_idxs = clients_idxs_train[client]
+        client_val_idxs = clients_idxs_val[client]
+        client_test_idxs = clients_idxs_test[client]
+
+        train_dataset_client = [
+            train_dataset[idx] for idx in client_train_idxs
+        ]
+        train_labels_client = [train_labels[idx] for idx in client_train_idxs]
+        labels_of_all_clients.append(train_labels_client)
+
+        val_dataset_client = [val_dataset[idx] for idx in client_val_idxs]
+        val_labels_client = [val_labels[idx] for idx in client_val_idxs]
+
+        test_dataset_client = [test_dataset[idx] for idx in client_test_idxs]
+        test_labels_client = [test_labels[idx] for idx in client_test_idxs]
+
+
+        partition_dict = {
+            "train": train_dataset_client,
+            "val": val_dataset_client,
+            "test": test_dataset_client,
+        }
+
+        partition_dicts[client] = partition_dict
+    global_data_dict = {
+        "train": train_dataset,
+        "val": val_dataset,
+        "test": test_dataset,
+    }
+
+    return global_data_dict, partition_dicts
+
+def load_partition_data(
+    args,
+    client_number,
+    name,
+    uniform=True,
+    global_test=True,
+    compact=True,
+    normalize_features=False,
+    normalize_adj=False,
+):
+    global_data_dict, partition_dicts = partition_data_by_sample_size(
+        args, client_number, name, uniform, compact=compact
+    )
+
+    data_local_num_dict = dict()
+    train_data_local_dict = dict()
+    val_data_local_dict = dict()
+    test_data_local_dict = dict()
+
+  
+
+    # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE.
+    train_data_global =  global_data_dict["train"]
+    val_data_global = global_data_dict["val"]
+    test_data_global = global_data_dict["test"]
+    train_data_num = len(global_data_dict["train"])
+    val_data_num = len(global_data_dict["val"])
+    test_data_num = len(global_data_dict["test"])
+
+    for client in range(client_number):
+        train_dataset_client = partition_dicts[client]["train"]
+        val_dataset_client = partition_dicts[client]["val"]
+        test_dataset_client = partition_dicts[client]["test"]
+
+        data_local_num_dict[client] = len(train_dataset_client)
+        train_data_local_dict[client] = train_dataset_client,
+ 
+        val_data_local_dict[client] = val_dataset_client
+  
+        test_data_local_dict[client] = (
+            test_data_global
+            if global_test
+            else test_dataset_client
+                
+        )
+
+        logging.info(
+            "Client idx = {}, local sample number = {}".format(
+                client, len(train_dataset_client)
+            )
+        )
+
+    return (
+        train_data_num,
+        val_data_num,
+        test_data_num,
+        train_data_global,
+        val_data_global,
+        test_data_global,
+        data_local_num_dict,
+        train_data_local_dict,
+        val_data_local_dict,
+        test_data_local_dict,
+    )
+
+
+
+def load_batch_level_dataset_main(name):
+    dataset = get_data(name)
+    graph, _ = dataset[0]
+    node_feature_dim = 0
+    for g, _ in dataset:
+        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
+    edge_feature_dim = 0
+    for g, _ in dataset:
+        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
+    node_feature_dim += 1
+    edge_feature_dim += 1
+    full_dataset = [i for i in range(len(dataset))]
+    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
+    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
+
+    return {'dataset': dataset,
+            'train_index': train_dataset,
+            'full_index': full_dataset,
+            'n_feat': node_feature_dim,
+            'e_feat': edge_feature_dim}
+            
+            
+class GraphDataset(dgl.data.DGLDataset):
+    def __init__(self, graph_label_list):
+        super(GraphDataset, self).__init__(name="wget")
+        self.graph_label_list = graph_label_list
+
+    def __len__(self):
+        return len(self.graph_label_list)
+
+    def __getitem__(self, idx):
+        graph, label = self.graph_label_list[idx]
+        # Convert the graph to a DGLGraph to work with DGL
+        return graph, label
+
+            
+def transform_data(data):
+    dataset = GraphDataset(data[0])
+    graph, _ = dataset[0]
+    node_feature_dim = 0
+    for g, _ in dataset:
+        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
+    edge_feature_dim = 0
+    for g, _ in dataset:
+        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
+    node_feature_dim += 1
+    edge_feature_dim += 1
+    full_dataset = [i for i in range(len(dataset))]
+    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
+    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
+
+    return {'dataset': dataset,
+            'train_index': train_dataset,
+            'full_index': full_dataset,
+            'n_feat': node_feature_dim,
+            'e_feat': edge_feature_dim}
+
diff --git a/data/theia/graphs.zip b/data/theia/graphs.zip
new file mode 100755
index 0000000000000000000000000000000000000000..3a98e4af6ec79f7199a8d843ff734bdce4ac59f9
Binary files /dev/null and b/data/theia/graphs.zip differ
diff --git a/data/trace/graphs.zip b/data/trace/graphs.zip
new file mode 100755
index 0000000000000000000000000000000000000000..98fdf390d57c6cec2ee07134ef7d4fe25accac2f
Binary files /dev/null and b/data/trace/graphs.zip differ
diff --git a/distance_save_cadets.pkl b/distance_save_cadets.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..1f2f428d6e2e6ffb8afbfd0d0669c46cab8718c6
Binary files /dev/null and b/distance_save_cadets.pkl differ
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b43a51b70c52b44fc56b884b82f961fc793515
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,90 @@
+import torch
+import warnings
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
+from model.autoencoder import build_model
+from utils.poolers import Pooling
+from utils.utils import set_random_seed
+import numpy as np
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.config import build_args
+warnings.filterwarnings('ignore')
+
+
+def main(main_args):
+    device = "cpu"
+    device = torch.device(device)
+    dataset_name = "trace"
+    if dataset_name in ['streamspot', 'wget']:
+        main_args.num_hidden = 256
+        main_args.num_layers = 4
+    else:
+        main_args["num_hidden"] = 64
+        main_args["num_layers"] = 3
+    set_random_seed(0)
+
+    if dataset_name == 'streamspot' or dataset_name == 'wget':
+        dataset = load_batch_level_dataset(dataset_name)
+        n_node_feat = dataset['n_feat']
+        n_edge_feat = dataset['e_feat']
+        main_args.n_dim = n_node_feat
+        main_args.e_dim = n_edge_feat
+        model = build_model(main_args)
+        model.load_state_dict(torch.load("./result/FedOpt-{}.pt".format(dataset_name), map_location=device))
+        model = model.to(device)
+        pooler = Pooling(main_args.pooling)
+        test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], args.dataset, main_args.n_dim,
+                                                    main_args.e_dim)
+    else:
+        metadata = load_metadata(dataset_name)
+        main_args["n_dim"] = metadata['node_feature_dim']
+        main_args["e_dim"] = metadata['edge_feature_dim']
+        model = build_model(main_args)
+        model.load_state_dict(torch.load("./result/checkpoint-{}.pt".format(dataset_name), map_location=device))
+        model = model.to(device)
+        model.eval()
+        malicious, _ = metadata['malicious']
+        n_train = metadata['n_train']
+        n_test = metadata['n_test']
+
+        with torch.no_grad():
+            x_train = []
+            for i in range(n_train):
+                g = load_entity_level_dataset(dataset_name, 'train', i).to(device)
+                x_train.append(model.embed(g).cpu().detach().numpy())
+                del g
+            x_train = np.concatenate(x_train, axis=0)
+            skip_benign = 0
+            x_test = []
+            for i in range(n_test):
+                g = load_entity_level_dataset(dataset_name, 'test', i).to(device)
+                # Exclude training samples from the test set
+                if i != n_test - 1:
+                    skip_benign += g.number_of_nodes()
+                x_test.append(model.embed(g).cpu().detach().numpy())
+                del g
+            x_test = np.concatenate(x_test, axis=0)
+
+            n = x_test.shape[0]
+            y_test = np.zeros(n)
+            y_test[malicious] = 1.0
+            malicious_dict = {}
+            for i, m in enumerate(malicious):
+                malicious_dict[m] = i
+
+            # Exclude training samples from the test set
+            test_idx = []
+            for i in range(x_test.shape[0]):
+                if i >= skip_benign or y_test[i] == 1.0:
+                    test_idx.append(i)
+            result_x_test = x_test[test_idx]
+            result_y_test = y_test[test_idx]
+            del x_test, y_test
+            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(dataset_name, x_train, result_x_test,
+                                                                       result_y_test)
+    print(f"#Test_AUC: {test_auc:.4f}±{test_std:.4f}")
+    return
+
+
+if __name__ == '__main__':
+    args = build_args()
+    main(args)
diff --git a/eval_result/distance_save_cadets.pkl b/eval_result/distance_save_cadets.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..2bb4de8e00a64a22e53b1a5187821e6e53dfffcc
Binary files /dev/null and b/eval_result/distance_save_cadets.pkl differ
diff --git a/fedml_config.yaml b/fedml_config.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..66714a20c7a2b471676dac8be0801b3ae132e5cf
--- /dev/null
+++ b/fedml_config.yaml
@@ -0,0 +1,49 @@
+common_args:
+  training_type: "simulation"
+  random_seed: 0
+  
+data_args:
+  dataset: "wget"
+  data_cache_dir: ~/fedgraphnn_data/
+  part_file:  ~/fedgraphnn_data/partition
+
+model_args:
+  model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
+  global_model_file_path: "./model_file_cache/global_model.pt"
+
+environment_args:
+  bootstrap: config/bootstrap.sh
+
+train_args:
+  federated_optimizer: "FedAvg"
+  client_id_list: 
+  client_num_in_total: 4
+  client_num_per_round: 4
+  comm_round: 100
+  lr: 0.001
+  server_lr: 0.001
+  wd: 0.001
+  ci: 0
+  server_momentum: 0.9
+
+validation_args:
+  frequency_of_the_test: 1
+
+device_args:
+  worker_num: 4
+  using_gpu: false
+  gpu_mapping_file: config/gpu_mapping.yaml
+  gpu_mapping_key: mapping_fedgraphnn_sp
+
+
+comm_args:
+  backend: "MPI"
+  is_mobile: 0
+
+
+tracking_args:
+   # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/
+  enable_wandb: false
+  wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
+  wandb_project: fedml
+  wandb_name: fedml_torch_moleculenet
diff --git a/main.py b/main.py
new file mode 100755
index 0000000000000000000000000000000000000000..a1a716e17f69f73d33fa62ef5b3da1e85b90b715
--- /dev/null
+++ b/main.py
@@ -0,0 +1,88 @@
+import logging
+
+import fedml
+from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split
+from fedml import FedMLRunner
+from trainer.magic_trainer import MagicTrainer
+from trainer.magic_aggregator import MagicWgetAggregator
+from model.autoencoder import build_model
+from utils.config import build_args
+from trainer.magic_trainer import MagicTrainer
+from trainer.magic_aggregator import MagicWgetAggregator
+from trainer.single_trainer import train_single
+from utils.loaddata import load_batch_level_dataset, load_metadata
+
+
+
+def generate_dataset(name, number):
+    (
+        train_data_num,
+        val_data_num,
+        test_data_num,
+        train_data_global,
+        val_data_global,
+        test_data_global,
+        data_local_num_dict,
+        train_data_local_dict,
+        val_data_local_dict,
+        test_data_local_dict,
+    ) = load_partition_data(None, number, name) 
+    dataset = [
+        train_data_num,
+        test_data_num,
+        train_data_global,
+        test_data_global,
+        data_local_num_dict,
+        train_data_local_dict,
+        test_data_local_dict,
+        len(train_data_global),
+    ]
+    
+    if (name == "wget" or name == "streamspot"):
+        
+        return dataset, load_batch_level_dataset(name)
+    else:
+        return dataset, load_metadata(name) 
+           
+    
+if __name__ == "__main__":
+    # init FedML framework
+    args = fedml.init()
+    # init device
+    device = fedml.device.get_device(args)
+    name = args.dataset
+    number = args.client_num_in_total
+    
+    dataset, metadata = generate_dataset(name, number)
+    main_args = build_args()
+    if (name == "wget"):
+        main_args["num_hidden"] = 256
+        main_args["max_epoch"] = 2
+        main_args["num_layers"] = 4
+        n_node_feat = metadata['n_feat']
+        n_edge_feat = metadata['e_feat']
+        main_args["n_dim"] = n_node_feat
+        main_args["e_dim"] = n_edge_feat
+    elif (name == "streamspot"):
+        main_args["num_hidden"] = 256
+        main_args["max_epoch"] = 5
+        main_args["num_layers"] = 4
+        n_node_feat = metadata['n_feat']
+        n_edge_feat = metadata['e_feat']
+        main_args["n_dim"] = n_node_feat
+        main_args["e_dim"] = n_edge_feat
+    else:
+        main_args["num_hidden"] = 64
+        main_args["max_epoch"] = 50
+        main_args["num_layers"] = 3
+        main_args["n_dim"] = metadata["node_feature_dim"]
+        main_args["e_dim"] = metadata["edge_feature_dim"]
+    
+    model = build_model(main_args)
+    #train_single(main_args, model, data)
+    trainer = MagicTrainer(model, args, name)
+    aggregator = MagicWgetAggregator(model, args, name)
+    fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator)
+    fedml_runner.run()
+    # start training
+    #darpa_split("theia")
diff --git a/model/__pycache__/autoencoder.cpython-311.pyc b/model/__pycache__/autoencoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13b41e99adb07cc4cfa61dd2771b34854aa88d80
Binary files /dev/null and b/model/__pycache__/autoencoder.cpython-311.pyc differ
diff --git a/model/__pycache__/eval.cpython-311.pyc b/model/__pycache__/eval.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c7d0bb5b6623e0d409dc28636e2722f423c0c7f0
Binary files /dev/null and b/model/__pycache__/eval.cpython-311.pyc differ
diff --git a/model/__pycache__/gat.cpython-311.pyc b/model/__pycache__/gat.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..0c9997f570d123655a70df66a86c3003e82f25ab
Binary files /dev/null and b/model/__pycache__/gat.cpython-311.pyc differ
diff --git a/model/__pycache__/loss_func.cpython-311.pyc b/model/__pycache__/loss_func.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..620a6bee7cb3ab8dc89bbc3448fdac41abb5f5ea
Binary files /dev/null and b/model/__pycache__/loss_func.cpython-311.pyc differ
diff --git a/model/__pycache__/train.cpython-311.pyc b/model/__pycache__/train.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c202e2105beb34aa9db7a6eddeffd4cef4b25fa3
Binary files /dev/null and b/model/__pycache__/train.cpython-311.pyc differ
diff --git a/model/autoencoder.py b/model/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..da51f590f02076ad43a0f8f40037ae4ad5b4c761
--- /dev/null
+++ b/model/autoencoder.py
@@ -0,0 +1,179 @@
+from .gat import GAT
+from utils.utils import create_norm
+from functools import partial
+from itertools import chain
+from .loss_func import sce_loss
+import torch
+import torch.nn as nn
+import dgl
+import random
+
+
+def build_model(args):
+    num_hidden = args["num_hidden"]
+    num_layers = args["num_layers"]
+    negative_slope = args["negative_slope"]
+    mask_rate = args["mask_rate"]
+    alpha_l = args["alpha_l"]
+    n_dim = args["n_dim"]
+    e_dim = args["e_dim"]
+
+    model = GMAEModel(
+        n_dim=n_dim,
+        e_dim=e_dim,
+        hidden_dim=num_hidden,
+        n_layers=num_layers,
+        n_heads=4,
+        activation="prelu",
+        feat_drop=0.1,
+        negative_slope=negative_slope,
+        residual=True,
+        mask_rate=mask_rate,
+        norm='BatchNorm',
+        loss_fn='sce',
+        alpha_l=alpha_l
+    )
+    return model
+
+
+class GMAEModel(nn.Module):
+    def __init__(self, n_dim, e_dim, hidden_dim, n_layers, n_heads, activation,
+                 feat_drop, negative_slope, residual, norm, mask_rate=0.5, loss_fn="sce", alpha_l=2):
+        super(GMAEModel, self).__init__()
+        self._mask_rate = mask_rate
+        self._output_hidden_size = hidden_dim
+        self.recon_loss = nn.BCELoss(reduction='mean')
+
+        def init_weights(m):
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform(m.weight)
+                nn.init.constant_(m.bias, 0)
+
+        self.edge_recon_fc = nn.Sequential(
+            nn.Linear(hidden_dim * n_layers * 2, hidden_dim),
+            nn.LeakyReLU(negative_slope),
+            nn.Linear(hidden_dim, 1),
+            nn.Sigmoid()
+        )
+        self.edge_recon_fc.apply(init_weights)
+
+        assert hidden_dim % n_heads == 0
+        enc_num_hidden = hidden_dim // n_heads
+        enc_nhead = n_heads
+
+        dec_in_dim = hidden_dim
+        dec_num_hidden = hidden_dim
+
+        # build encoder
+        self.encoder = GAT(
+            n_dim=n_dim,
+            e_dim=e_dim,
+            hidden_dim=enc_num_hidden,
+            out_dim=enc_num_hidden,
+            n_layers=n_layers,
+            n_heads=enc_nhead,
+            n_heads_out=enc_nhead,
+            concat_out=True,
+            activation=activation,
+            feat_drop=feat_drop,
+            attn_drop=0.0,
+            negative_slope=negative_slope,
+            residual=residual,
+            norm=create_norm(norm),
+            encoding=True,
+        )
+
+        # build decoder for attribute prediction
+        self.decoder = GAT(
+            n_dim=dec_in_dim,
+            e_dim=e_dim,
+            hidden_dim=dec_num_hidden,
+            out_dim=n_dim,
+            n_layers=1,
+            n_heads=n_heads,
+            n_heads_out=1,
+            concat_out=True,
+            activation=activation,
+            feat_drop=feat_drop,
+            attn_drop=0.0,
+            negative_slope=negative_slope,
+            residual=residual,
+            norm=create_norm(norm),
+            encoding=False,
+        )
+
+        self.enc_mask_token = nn.Parameter(torch.zeros(1, n_dim))
+        self.encoder_to_decoder = nn.Linear(dec_in_dim * n_layers, dec_in_dim, bias=False)
+
+        # * setup loss function
+        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)
+
+    @property
+    def output_hidden_dim(self):
+        return self._output_hidden_size
+
+    def setup_loss_fn(self, loss_fn, alpha_l):
+        if loss_fn == "sce":
+            criterion = partial(sce_loss, alpha=alpha_l)
+        else:
+            raise NotImplementedError
+        return criterion
+
+    def encoding_mask_noise(self, g, mask_rate=0.3):
+        new_g = g.clone()
+        num_nodes = g.num_nodes()
+        perm = torch.randperm(num_nodes, device=g.device)
+
+        # random masking
+        num_mask_nodes = int(mask_rate * num_nodes)
+        mask_nodes = perm[: num_mask_nodes]
+        keep_nodes = perm[num_mask_nodes:]
+
+        new_g.ndata["attr"][mask_nodes] = self.enc_mask_token
+
+        return new_g, (mask_nodes, keep_nodes)
+
+    def forward(self, g):
+        loss = self.compute_loss(g)
+        return loss
+
+    def compute_loss(self, g):
+        # Feature Reconstruction
+        pre_use_g, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, self._mask_rate)
+        pre_use_x = pre_use_g.ndata['attr'].to(pre_use_g.device)
+        use_g = pre_use_g
+        enc_rep, all_hidden = self.encoder(use_g, pre_use_x, return_hidden=True)
+        enc_rep = torch.cat(all_hidden, dim=1)
+        rep = self.encoder_to_decoder(enc_rep)
+
+        recon = self.decoder(pre_use_g, rep)
+        x_init = g.ndata['attr'][mask_nodes]
+        x_rec = recon[mask_nodes]
+        loss = self.criterion(x_rec, x_init)
+
+        # Structural Reconstruction
+        threshold = min(10000, g.num_nodes())
+
+        negative_edge_pairs = dgl.sampling.global_uniform_negative_sampling(g, threshold)
+        positive_edge_pairs = random.sample(range(g.number_of_edges()), threshold)
+        positive_edge_pairs = (g.edges()[0][positive_edge_pairs], g.edges()[1][positive_edge_pairs])
+        sample_src = enc_rep[torch.cat([positive_edge_pairs[0], negative_edge_pairs[0]])].to(g.device)
+        sample_dst = enc_rep[torch.cat([positive_edge_pairs[1], negative_edge_pairs[1]])].to(g.device)
+        y_pred = self.edge_recon_fc(torch.cat([sample_src, sample_dst], dim=-1)).squeeze(-1)
+        y = torch.cat([torch.ones(len(positive_edge_pairs[0])), torch.zeros(len(negative_edge_pairs[0]))]).to(
+            g.device)
+        loss += self.recon_loss(y_pred, y)
+        return loss
+
+    def embed(self, g):
+        x = g.ndata['attr'].to(g.device)
+        rep = self.encoder(g, x)
+        return rep
+
+    @property
+    def enc_params(self):
+        return self.encoder.parameters()
+
+    @property
+    def dec_params(self):
+        return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
diff --git a/model/eval.py b/model/eval.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff777aeae5f9a8664fe89c384334eb93d02c5f44
--- /dev/null
+++ b/model/eval.py
@@ -0,0 +1,241 @@
+import os
+import random
+import time
+import pickle as pkl
+import torch
+import numpy as np
+from sklearn.metrics import roc_auc_score, precision_recall_curve
+from sklearn.neighbors import NearestNeighbors
+from utils.utils import set_random_seed
+from data.data_loader import load_batch_level_dataset_main
+from utils.loaddata import transform_graph, load_batch_level_dataset
+
+
+def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0):
+    model.eval()
+    x_list = []
+    y_list = []
+    data = load_batch_level_dataset(dataset)
+    full = data['full_index']
+    graphs = data['dataset']
+    with torch.no_grad():
+        for i in full:
+            g = transform_graph(graphs[i][0], n_dim, e_dim).to(device)
+            label = graphs[i][1]
+            out = model.embed(g)
+            if dataset != 'wget':
+                out = pooler(g, out).cpu().numpy()
+            else:
+                out = pooler(g, out, [2]).cpu().numpy()
+            y_list.append(label)
+            x_list.append(out)
+    x = np.concatenate(x_list, axis=0)
+    y = np.array(y_list)
+    if 'knn' in method:
+        test_auc, test_std = evaluate_batch_level_using_knn(-1, dataset, x, y)
+    else:
+        raise NotImplementedError
+    return test_auc, test_std
+
+
+def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels):
+    x, y = embeddings, labels
+    if dataset == 'streamspot':
+        train_count = 400
+    else:
+        train_count = 100
+    n_neighbors = min(int(train_count * 0.02), 10)
+    benign_idx = np.where(y == 0)[0]
+    attack_idx = np.where(y == 1)[0]
+    if repeat != -1:
+        prec_list = []
+        rec_list = []
+        f1_list = []
+        tp_list = []
+        fp_list = []
+        tn_list = []
+        fn_list = []
+        auc_list = []
+        for s in range(repeat):
+            set_random_seed(s)
+            np.random.shuffle(benign_idx)
+            np.random.shuffle(attack_idx)
+            x_train = x[benign_idx[:train_count]]
+            x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0)
+            y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0)
+            x_train_mean = x_train.mean(axis=0)
+            x_train_std = x_train.std(axis=0)
+            x_train = (x_train - x_train_mean) / x_train_std
+            x_test = (x_test - x_train_mean) / x_train_std
+
+            nbrs = NearestNeighbors(n_neighbors=n_neighbors)
+            nbrs.fit(x_train)
+            distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors)
+            mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1)
+            distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors)
+
+            score = distances.mean(axis=1) / mean_distance
+
+            auc = roc_auc_score(y_test, score)
+            prec, rec, threshold = precision_recall_curve(y_test, score)
+            f1 = 2 * prec * rec / (rec + prec + 1e-9)
+            max_f1_idx = np.argmax(f1)
+            best_thres = threshold[max_f1_idx]
+            prec_list.append(prec[max_f1_idx])
+            rec_list.append(rec[max_f1_idx])
+            f1_list.append(f1[max_f1_idx])
+
+            tn = 0
+            fn = 0
+            tp = 0
+            fp = 0
+            for i in range(len(y_test)):
+                if y_test[i] == 1.0 and score[i] >= best_thres:
+                    tp += 1
+                if y_test[i] == 1.0 and score[i] < best_thres:
+                    fn += 1
+                if y_test[i] == 0.0 and score[i] < best_thres:
+                    tn += 1
+                if y_test[i] == 0.0 and score[i] >= best_thres:
+                    fp += 1
+            tp_list.append(tp)
+            fp_list.append(fp)
+            fn_list.append(fn)
+            tn_list.append(tn)
+            auc_list.append(auc)
+
+        print('AUC: {}+{}'.format(np.mean(auc_list), np.std(auc_list)))
+        print('F1: {}+{}'.format(np.mean(f1_list), np.std(f1_list)))
+        print('PRECISION: {}+{}'.format(np.mean(prec_list), np.std(prec_list)))
+        print('RECALL: {}+{}'.format(np.mean(rec_list), np.std(rec_list)))
+        print('TN: {}+{}'.format(np.mean(tn_list), np.std(tn_list)))
+        print('FN: {}+{}'.format(np.mean(fn_list), np.std(fn_list)))
+        print('TP: {}+{}'.format(np.mean(tp_list), np.std(tp_list)))
+        print('FP: {}+{}'.format(np.mean(fp_list), np.std(fp_list)))
+        return np.mean(auc_list), np.std(auc_list)
+    else:
+        set_random_seed(2022)
+        np.random.shuffle(benign_idx)
+        np.random.shuffle(attack_idx)
+        x_train = x[benign_idx[:train_count]]
+        x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0)
+        y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0)
+        x_train_mean = x_train.mean(axis=0)
+        x_train_std = x_train.std(axis=0)
+        for i in range(len(x_train_std)):
+           if (x_train_std[i] == 0 ):
+               x_train_std[i] = 0.000000000000001
+        x_train = (x_train - x_train_mean) / x_train_std
+        x_test = (x_test - x_train_mean) / x_train_std
+
+        nbrs = NearestNeighbors(n_neighbors=n_neighbors)
+        nbrs.fit(x_train)
+        distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors)
+        mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1)
+        distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors)
+
+        score = distances.mean(axis=1) / mean_distance
+        auc = roc_auc_score(y_test, score)
+        prec, rec, threshold = precision_recall_curve(y_test, score)
+        f1 = 2 * prec * rec / (rec + prec + 1e-9)
+        best_idx = np.argmax(f1)
+        best_thres = threshold[best_idx]
+
+        tn = 0
+        fn = 0
+        tp = 0
+        fp = 0
+        for i in range(len(y_test)):
+            if y_test[i] == 1.0 and score[i] >= best_thres:
+                tp += 1
+            if y_test[i] == 1.0 and score[i] < best_thres:
+                fn += 1
+            if y_test[i] == 0.0 and score[i] < best_thres:
+                tn += 1
+            if y_test[i] == 0.0 and score[i] >= best_thres:
+                fp += 1
+        print('AUC: {}'.format(auc))
+        print('F1: {}'.format(f1[best_idx]))
+        print('PRECISION: {}'.format(prec[best_idx]))
+        print('RECALL: {}'.format(rec[best_idx]))
+        print('TN: {}'.format(tn))
+        print('FN: {}'.format(fn))
+        print('TP: {}'.format(tp))
+        print('FP: {}'.format(fp))
+        return auc, 0.0
+
+
+def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test):
+    x_train_mean = x_train.mean(axis=0)
+    x_train_std = x_train.std(axis=0)
+    x_train = (x_train - x_train_mean) / x_train_std
+    x_test = (x_test - x_train_mean) / x_train_std
+
+    if dataset == 'cadets':
+        n_neighbors = 200
+    else:
+        n_neighbors = 10
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1)
+    nbrs.fit(x_train)
+
+    save_dict_path = './eval_result/distance_save_{}.pkl'.format(dataset)
+    if not os.path.exists(save_dict_path):
+        idx = list(range(x_train.shape[0]))
+        random.shuffle(x_train[idx][:min(50000, x_train.shape[0])])
+        print(len(x_train[idx][:min(50000, x_train.shape[0])]))
+        distances, _ = nbrs.kneighbors(x_train[idx][:min(50000, x_train.shape[0])], n_neighbors=n_neighbors)
+        del x_train
+        mean_distance = distances.mean()
+        del distances
+        print('here')
+        print(len(x_test))
+        distances, _ = nbrs.kneighbors(x_test, n_neighbors=n_neighbors)
+        print("yes")
+        save_dict = [mean_distance, distances.mean(axis=1)]
+        distances = distances.mean(axis=1)
+        with open(save_dict_path, 'wb') as f:
+            pkl.dump(save_dict, f)
+    else:
+        with open(save_dict_path, 'rb') as f:
+            mean_distance, distances = pkl.load(f)
+    score = distances / mean_distance
+    del distances
+    auc = roc_auc_score(y_test, score)
+    prec, rec, threshold = precision_recall_curve(y_test, score)
+    f1 = 2 * prec * rec / (rec + prec + 1e-9)
+    best_idx = -1
+    for i in range(len(f1)):
+        # To repeat peak performance
+        if dataset == 'trace' and rec[i] < 0.99979:
+            best_idx = i - 1
+            break
+        if dataset == 'theia' and rec[i] < 0.99996:
+            best_idx = i - 1
+            break
+        if dataset == 'cadets' and rec[i] < 0.9976:
+            best_idx = i - 1
+            break
+    best_thres = threshold[best_idx]
+
+    tn = 0
+    fn = 0
+    tp = 0
+    fp = 0
+    for i in range(len(y_test)):
+        if y_test[i] == 1.0 and score[i] >= best_thres:
+            tn += 1
+        if y_test[i] == 1.0 and score[i] < best_thres:
+            fn += 1
+        if y_test[i] == 0.0 and score[i] < best_thres:
+            tp += 1
+        if y_test[i] == 0.0 and score[i] >= best_thres:
+            fp += 1
+    print('AUC: {}'.format(auc))
+    print('F1: {}'.format(f1[best_idx]))
+    print('PRECISION: {}'.format(prec[best_idx]))
+    print('RECALL: {}'.format(rec[best_idx]))
+    print('TN: {}'.format(tn))
+    print('FN: {}'.format(fn))
+    print('TP: {}'.format(tp))
+    print('FP: {}'.format(fp))
+    return auc, 0.0, None, None
diff --git a/model/gat.py b/model/gat.py
new file mode 100755
index 0000000000000000000000000000000000000000..c64054f0dc985f37d6c1588afc5793eaf36417a6
--- /dev/null
+++ b/model/gat.py
@@ -0,0 +1,233 @@
+import torch
+import torch.nn as nn
+from dgl.ops import edge_softmax
+import dgl.function as fn
+from dgl.utils import expand_as_pair
+from utils.utils import create_activation
+
+
+class GAT(nn.Module):
+    def __init__(self,
+                 n_dim,
+                 e_dim,
+                 hidden_dim,
+                 out_dim,
+                 n_layers,
+                 n_heads,
+                 n_heads_out,
+                 activation,
+                 feat_drop,
+                 attn_drop,
+                 negative_slope,
+                 residual,
+                 norm,
+                 concat_out=False,
+                 encoding=False
+                 ):
+        super(GAT, self).__init__()
+        self.out_dim = out_dim
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.gats = nn.ModuleList()
+        self.concat_out = concat_out
+
+        last_activation = create_activation(activation) if encoding else None
+        last_residual = (encoding and residual)
+        last_norm = norm if encoding else None
+
+        if self.n_layers == 1:
+            self.gats.append(GATConv(
+                n_dim, e_dim, out_dim, n_heads_out, feat_drop, attn_drop, negative_slope,
+                last_residual, norm=last_norm, concat_out=self.concat_out
+            ))
+        else:
+            self.gats.append(GATConv(
+                n_dim, e_dim, hidden_dim, n_heads, feat_drop, attn_drop, negative_slope,
+                residual, create_activation(activation),
+                norm=norm, concat_out=self.concat_out
+            ))
+            for _ in range(1, self.n_layers - 1):
+                self.gats.append(GATConv(
+                    hidden_dim * self.n_heads, e_dim, hidden_dim, n_heads,
+                    feat_drop, attn_drop, negative_slope,
+                    residual, create_activation(activation),
+                    norm=norm, concat_out=self.concat_out
+                ))
+            self.gats.append(GATConv(
+                hidden_dim * self.n_heads, e_dim, out_dim, n_heads_out,
+                feat_drop, attn_drop, negative_slope,
+                last_residual, last_activation, norm=last_norm, concat_out=self.concat_out
+            ))
+        self.head = nn.Identity()
+
+    def forward(self, g, input_feature, return_hidden=False):
+        h = input_feature
+        hidden_list = []
+        for layer in range(self.n_layers):
+            h = self.gats[layer](g, h)
+            hidden_list.append(h)
+        if return_hidden:
+            return self.head(h), hidden_list
+        else:
+            return self.head(h)
+
+    def reset_classifier(self, num_classes):
+        self.head = nn.Linear(self.num_heads * self.out_dim, num_classes)
+
+
+class GATConv(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 e_dim,
+                 out_dim,
+                 n_heads,
+                 feat_drop=0.0,
+                 attn_drop=0.0,
+                 negative_slope=0.2,
+                 residual=False,
+                 activation=None,
+                 allow_zero_in_degree=False,
+                 bias=True,
+                 norm=None,
+                 concat_out=True):
+        super(GATConv, self).__init__()
+        self.n_heads = n_heads
+        self.src_feat, self.dst_feat = expand_as_pair(in_dim)
+        self.edge_feat = e_dim
+        self.out_feat = out_dim
+        self.allow_zero_in_degree = allow_zero_in_degree
+        self.concat_out = concat_out
+
+        if isinstance(in_dim, tuple):
+            self.fc_node_embedding = nn.Linear(
+                self.src_feat, self.out_feat * self.n_heads, bias=False)
+            self.fc_src = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False)
+            self.fc_dst = nn.Linear(self.dst_feat, self.out_feat * self.n_heads, bias=False)
+        else:
+            self.fc_node_embedding = nn.Linear(
+                self.src_feat, self.out_feat * self.n_heads, bias=False)
+            self.fc = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False)
+        self.edge_fc = nn.Linear(self.edge_feat, self.out_feat * self.n_heads, bias=False)
+        self.attn_h = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat)))
+        self.attn_e = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat)))
+        self.attn_t = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat)))
+        self.feat_drop = nn.Dropout(feat_drop)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.leaky_relu = nn.LeakyReLU(negative_slope)
+        if bias:
+            self.bias = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat)))
+        else:
+            self.register_buffer('bias', None)
+        if residual:
+            if self.dst_feat != self.n_heads * self.out_feat:
+                self.res_fc = nn.Linear(
+                    self.dst_feat, self.n_heads * self.out_feat, bias=False)
+            else:
+                self.res_fc = nn.Identity()
+        else:
+            self.register_buffer('res_fc', None)
+        self.reset_parameters()
+        self.activation = activation
+        self.norm = norm
+        if norm is not None:
+            self.norm = norm(self.n_heads * self.out_feat)
+
+    def reset_parameters(self):
+        gain = nn.init.calculate_gain('relu')
+        nn.init.xavier_normal_(self.edge_fc.weight, gain=gain)
+        if hasattr(self, 'fc'):
+            nn.init.xavier_normal_(self.fc.weight, gain=gain)
+        else:
+            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
+            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
+        nn.init.xavier_normal_(self.attn_h, gain=gain)
+        nn.init.xavier_normal_(self.attn_e, gain=gain)
+        nn.init.xavier_normal_(self.attn_t, gain=gain)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0)
+        if isinstance(self.res_fc, nn.Linear):
+            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
+
+    def set_allow_zero_in_degree(self, set_value):
+        self.allow_zero_in_degree = set_value
+
+    def forward(self, graph, feat, get_attention=False):
+        edge_feature = graph.edata['attr']
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                src_prefix_shape = feat[0].shape[:-1]
+                dst_prefix_shape = feat[1].shape[:-1]
+                h_src = self.feat_drop(feat[0])
+                h_dst = self.feat_drop(feat[1])
+                if not hasattr(self, 'fc_src'):
+                    feat_src = self.fc(h_src).view(
+                        *src_prefix_shape, self.n_heads, self.out_feat)
+                    feat_dst = self.fc(h_dst).view(
+                        *dst_prefix_shape, self.n_heads, self.out_feat)
+                else:
+                    feat_src = self.fc_src(h_src).view(
+                        *src_prefix_shape, self.n_heads, self.out_feat)
+                    feat_dst = self.fc_dst(h_dst).view(
+                        *dst_prefix_shape, self.n_heads, self.out_feat)
+            else:
+                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
+                h_src = h_dst = self.feat_drop(feat)
+                feat_src = feat_dst = self.fc(h_src).view(
+                    *src_prefix_shape, self.n_heads, self.out_feat)
+                if graph.is_block:
+                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
+                    h_dst = h_dst[:graph.number_of_dst_nodes()]
+                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
+            edge_prefix_shape = edge_feature.shape[:-1]
+            eh = (feat_src * self.attn_h).sum(-1).unsqueeze(-1)
+            et = (feat_dst * self.attn_t).sum(-1).unsqueeze(-1)
+
+            graph.srcdata.update({'hs': feat_src, 'eh': eh})
+            graph.dstdata.update({'et': et})
+
+            feat_edge = self.edge_fc(edge_feature).view(
+                *edge_prefix_shape, self.n_heads, self.out_feat)
+            ee = (feat_edge * self.attn_e).sum(-1).unsqueeze(-1)
+
+            graph.edata.update({'ee': ee})
+            graph.apply_edges(fn.u_add_e('eh', 'ee', 'ee'))
+            graph.apply_edges(fn.e_add_v('ee', 'et', 'e'))
+            """
+            graph.apply_edges(fn.u_add_v('eh', 'et', 'e'))
+            """
+            e = self.leaky_relu(graph.edata.pop('e'))
+            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+            # message passing
+
+            graph.update_all(fn.u_mul_e('hs', 'a', 'm'),
+                             fn.sum('m', 'hs'))
+
+            rst = graph.dstdata['hs'].view(-1, self.n_heads, self.out_feat)
+
+            if self.bias is not None:
+                rst = rst + self.bias.view(
+                    *((1,) * len(dst_prefix_shape)), self.n_heads, self.out_feat)
+
+            # residual
+
+            if self.res_fc is not None:
+                # Use -1 rather than self._num_heads to handle broadcasting
+                resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self.out_feat)
+                rst = rst + resval
+
+            if self.concat_out:
+                rst = rst.flatten(1)
+            else:
+                rst = torch.mean(rst, dim=1)
+
+            if self.norm is not None:
+                rst = self.norm(rst)
+
+                # activation
+            if self.activation:
+                rst = self.activation(rst)
+
+            if get_attention:
+                return rst, graph.edata['a']
+            else:
+                return rst
diff --git a/model/loss_func.py b/model/loss_func.py
new file mode 100755
index 0000000000000000000000000000000000000000..a2eff3e79376bf6ec4157b3c56f3a6ebd50a3127
--- /dev/null
+++ b/model/loss_func.py
@@ -0,0 +1,9 @@
+import torch.nn.functional as F
+
+
+def sce_loss(x, y, alpha=3):
+    x = F.normalize(x, p=2, dim=-1)
+    y = F.normalize(y, p=2, dim=-1)
+    loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
+    loss = loss.mean()
+    return loss
diff --git a/model/mlp.py b/model/mlp.py
new file mode 100755
index 0000000000000000000000000000000000000000..e1ee9581a033acdbec3b4501cdd45c2d72e7efd9
--- /dev/null
+++ b/model/mlp.py
@@ -0,0 +1,13 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLP(nn.Module):
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(MLP, self).__init__()
+        self.w_1 = nn.Linear(d_model, d_ff)
+        self.w_2 = nn.Linear(d_ff, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        return self.w_2(self.dropout(F.relu(self.w_1(x))))
diff --git a/model/train.py b/model/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..9ec20ff2f7ec83a695904dff381cb10191d72bfc
--- /dev/null
+++ b/model/train.py
@@ -0,0 +1,28 @@
+import dgl
+import numpy as np
+from tqdm import tqdm
+from utils.loaddata import transform_graph
+
+
+def batch_level_train(model, graphs, train_loader, optimizer, max_epoch, device, n_dim=0, e_dim=0):
+    epoch_iter = tqdm(range(max_epoch))
+    cpt = 1
+    for epoch in epoch_iter:
+        model.train()
+        loss_list = []
+        cpt = 0
+        for _, batch in enumerate(train_loader):
+            cpt +=1
+            print(cpt)
+            batch_g = [transform_graph(graphs[idx][0], n_dim, e_dim).to(device) for idx in batch]
+            batch_g = dgl.batch(batch_g)
+            model.train()
+            loss = model(batch_g)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            loss_list.append(loss.item())
+            del batch_g
+        epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}")
+    return model
diff --git a/mpi_host_file b/mpi_host_file
new file mode 100755
index 0000000000000000000000000000000000000000..ee700f9141190f2dace3d92aa89e1d95288527f0
--- /dev/null
+++ b/mpi_host_file
@@ -0,0 +1 @@
+kamel-virtual-machine
diff --git a/requirement.txt b/requirement.txt
new file mode 100755
index 0000000000000000000000000000000000000000..f4f68bfb59c5119dc0c32cbb3c1c1f1140b925a6
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,133 @@
+aiohttp==3.9.1
+aiosignal==1.3.1
+anyio==4.2.0
+attrdict==2.0.1
+attrs==23.2.0
+blis==0.7.11
+boto3==1.34.12
+botocore==1.34.12
+Brotli==1.1.0
+catalogue==2.0.10
+certifi==2023.11.17
+chardet==5.2.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpathlib==0.16.0
+confection==0.1.4
+contourpy==1.2.0
+cycler==0.12.1
+cymem==2.0.8
+dgl==1.1.3
+dill==0.3.7
+docker==6.1.3
+docker-pycreds==0.4.0
+fastapi==0.92.0
+fedml==0.8.13.post2
+filelock==3.13.1
+fonttools==4.47.0
+frozenlist==1.4.1
+fsspec==2023.12.2
+gensim==4.3.2
+gevent==23.9.1
+geventhttpclient==2.0.9
+gitdb==4.0.11
+GitPython==3.1.40
+GPUtil==1.4.0
+graphviz==0.8.4
+greenlet==3.0.3
+h11==0.14.0
+h5py==3.10.0
+httpcore==1.0.2
+httpx==0.26.0
+idna==3.6
+Jinja2==3.1.2
+jmespath==1.0.1
+joblib==1.3.2
+kiwisolver==1.4.5
+langcodes==3.3.0
+MarkupSafe==2.1.3
+matplotlib==3.8.2
+mpi4py @ file:///home/conda/feedstock_root/build_artifacts/mpi4py_1697529104664/work
+mpmath==1.3.0
+multidict==6.0.4
+multiprocess==0.70.15
+murmurhash==1.0.10
+networkx==2.8.8
+ntplib==0.4.0
+numpy==1.26.3
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.18.1
+nvidia-nvjitlink-cu12==12.3.101
+nvidia-nvtx-cu12==12.1.105
+onnx==1.15.0
+packaging==23.2
+paho-mqtt==1.6.1
+pandas==2.1.4
+pathtools==0.1.2
+pillow==10.2.0
+preshed==3.0.9
+prettytable==3.9.0
+promise==2.3
+protobuf==3.20.3
+psutil==5.9.7
+py-machineid==0.4.6
+pydantic==1.10.13
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-rapidjson==1.14
+pytz==2023.3.post1
+PyYAML==6.0.1
+redis==5.0.1
+requests==2.31.0
+s3transfer==0.10.0
+scikit-learn==1.3.2
+scipy==1.11.4
+sentry-sdk==1.39.1
+setproctitle==1.3.3
+shortuuid==1.0.11
+six==1.16.0
+smart-open==6.3.0
+smmap==5.0.1
+sniffio==1.3.0
+spacy==3.7.2
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+SQLAlchemy==2.0.25
+srsly==2.4.8
+starlette==0.25.0
+sympy==1.12
+thinc==8.2.2
+threadpoolctl==3.2.0
+torch==2.1.2
+torch-cluster==1.6.3
+torch-scatter==2.1.2
+torch-sparse==0.6.18
+torch-spline-conv==1.2.2
+torch_geometric==2.4.0
+torchvision==0.16.2
+tqdm==4.66.1
+triton==2.1.0
+tritonclient==2.41.0
+typer==0.9.0
+typing_extensions==4.9.0
+tzdata==2023.4
+tzlocal==5.2
+urllib3==2.0.7
+uvicorn==0.25.0
+wandb==0.13.2
+wasabi==1.1.2
+wcwidth==0.2.12
+weasel==0.3.4
+websocket-client==1.7.0
+wget==3.2
+yarl==1.9.4
+zope.event==5.0
+zope.interface==6.1
diff --git a/result/FedAvg-2client-2round-cadets.pt b/result/FedAvg-2client-2round-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..97537e00b83117c06200ca53c0e22cef67b81fe8
Binary files /dev/null and b/result/FedAvg-2client-2round-cadets.pt differ
diff --git a/result/FedAvg-2client-cadets.pt b/result/FedAvg-2client-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b9fb122af9072298b859056fa94be04932ef7a8b
Binary files /dev/null and b/result/FedAvg-2client-cadets.pt differ
diff --git a/result/FedAvg-4client-cadets.pt b/result/FedAvg-4client-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4fb2cd9db49d52015b81ce437063abfa0c17d065
Binary files /dev/null and b/result/FedAvg-4client-cadets.pt differ
diff --git a/result/FedAvg-cadets.pt b/result/FedAvg-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..011e3d240dd0a104c0f4bc54541cb31c8cab5fcb
Binary files /dev/null and b/result/FedAvg-cadets.pt differ
diff --git a/result/FedAvg-theia.pt b/result/FedAvg-theia.pt
new file mode 100755
index 0000000000000000000000000000000000000000..66b64fb4c8b42b811c28d7c37238c88d5122dbd2
Binary files /dev/null and b/result/FedAvg-theia.pt differ
diff --git a/result/FedAvg_Streamspot-streamspot.pt b/result/FedAvg_Streamspot-streamspot.pt
new file mode 100755
index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead
Binary files /dev/null and b/result/FedAvg_Streamspot-streamspot.pt differ
diff --git a/result/FedOpt-cadets.pt b/result/FedOpt-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..291831ac85f89a3b16648d005a814a064f1c2cdf
Binary files /dev/null and b/result/FedOpt-cadets.pt differ
diff --git a/result/FedOpt-theia.pt b/result/FedOpt-theia.pt
new file mode 100755
index 0000000000000000000000000000000000000000..383ebdb7544ff96e65dc408b857d5f61803b14e4
Binary files /dev/null and b/result/FedOpt-theia.pt differ
diff --git a/result/FedOpt-trace.pt b/result/FedOpt-trace.pt
new file mode 100644
index 0000000000000000000000000000000000000000..3e7faf31ca32212b8158216eaab0be77c2696ada
Binary files /dev/null and b/result/FedOpt-trace.pt differ
diff --git a/result/FedOpt_Streamspot.pt b/result/FedOpt_Streamspot.pt
new file mode 100755
index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead
Binary files /dev/null and b/result/FedOpt_Streamspot.pt differ
diff --git a/result/FedProx-cadets.pt b/result/FedProx-cadets.pt
new file mode 100644
index 0000000000000000000000000000000000000000..deaf520dbef791e5e8520370443acffe65e8825b
Binary files /dev/null and b/result/FedProx-cadets.pt differ
diff --git a/result/FedProx-theia.pt b/result/FedProx-theia.pt
new file mode 100755
index 0000000000000000000000000000000000000000..100b9f885317fcc267d0fb44a53ca3928c908a4e
Binary files /dev/null and b/result/FedProx-theia.pt differ
diff --git a/result/FedProx-trace.pt b/result/FedProx-trace.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b34e70d94c6e557d0c2f92efd1d2082076393c38
Binary files /dev/null and b/result/FedProx-trace.pt differ
diff --git a/result/FedProx_Streamspot.pt b/result/FedProx_Streamspot.pt
new file mode 100755
index 0000000000000000000000000000000000000000..232f436cc87c609d59174ee35b8bebc1fcda303b
Binary files /dev/null and b/result/FedProx_Streamspot.pt differ
diff --git a/result/checkpoint-trace.pt b/result/checkpoint-trace.pt
new file mode 100644
index 0000000000000000000000000000000000000000..880702e81259a697a9a794bb943927b2327b7420
Binary files /dev/null and b/result/checkpoint-trace.pt differ
diff --git a/save_results/distance_save_cadets_FedAvg.pkl b/save_results/distance_save_cadets_FedAvg.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..870244cd32749e587a2b411dc1c13dee18d08781
Binary files /dev/null and b/save_results/distance_save_cadets_FedAvg.pkl differ
diff --git a/save_results/distance_save_cadets_FedOpt.pkl b/save_results/distance_save_cadets_FedOpt.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..9bdff9888dbf1ee8f9fb45f255b54502b4b56d6c
Binary files /dev/null and b/save_results/distance_save_cadets_FedOpt.pkl differ
diff --git a/save_results/distance_save_cadets_FedProx.pkl b/save_results/distance_save_cadets_FedProx.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..8680ed4c84080861d938bc5ea620d9d8589fc39a
Binary files /dev/null and b/save_results/distance_save_cadets_FedProx.pkl differ
diff --git a/save_results/distance_save_theia_FedAvg.pkl b/save_results/distance_save_theia_FedAvg.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..d885a972d533697733d1c953742df203a5c183a9
Binary files /dev/null and b/save_results/distance_save_theia_FedAvg.pkl differ
diff --git a/save_results/distance_save_theia_FedOpt.pkl b/save_results/distance_save_theia_FedOpt.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3e84da9be1a793612b77573f16b80f4bba3f530b
Binary files /dev/null and b/save_results/distance_save_theia_FedOpt.pkl differ
diff --git a/save_results/distance_save_theia_FedProx.pkl b/save_results/distance_save_theia_FedProx.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..80e772d9f66424b552809cec77dae53a291f83e4
Binary files /dev/null and b/save_results/distance_save_theia_FedProx.pkl differ
diff --git a/save_results/distance_save_trace_FedAvg.pkl b/save_results/distance_save_trace_FedAvg.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..12baf82b2236f724c1c9f65a3adb19838a693c4b
Binary files /dev/null and b/save_results/distance_save_trace_FedAvg.pkl differ
diff --git a/save_results/distance_save_trace_FedOpt.pkl b/save_results/distance_save_trace_FedOpt.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..6d90af670837a994ba274a3a3a2d09ea5b941d31
Binary files /dev/null and b/save_results/distance_save_trace_FedOpt.pkl differ
diff --git a/save_results/distance_save_trace_FedProx.pkl b/save_results/distance_save_trace_FedProx.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..e88de86f4bcb805e0475a41d4de3bbfde5d7353e
Binary files /dev/null and b/save_results/distance_save_trace_FedProx.pkl differ
diff --git a/test.py b/test.py
new file mode 100755
index 0000000000000000000000000000000000000000..c1f6ebe728019b94bfcd9513907c28a9d39eaa81
--- /dev/null
+++ b/test.py
@@ -0,0 +1,15 @@
+from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split
+
+
+(
+        train_data_num,
+        val_data_num,
+        test_data_num,
+        train_data_global,
+        val_data_global,
+        test_data_global,
+        data_local_num_dict,
+        train_data_local_dict,
+        val_data_local_dict,
+        test_data_local_dict,
+    ) = load_partition_data(None, 3, "streamspot") 
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..968181a1cfaa6de944115c62b2b187376c734502
--- /dev/null
+++ b/train.py
@@ -0,0 +1,90 @@
+import os
+import random
+import torch
+import warnings
+from tqdm import tqdm
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph
+from model.autoencoder import build_model
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+import dgl
+from model.train import batch_level_train
+from utils.utils import set_random_seed, create_optimizer
+from utils.config import build_args
+warnings.filterwarnings('ignore')
+
+
+def extract_dataloaders(entries, batch_size):
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+def main(main_args):
+    device = "cpu"
+    dataset_name =  "trace"
+    if dataset_name == 'streamspot':
+        main_args.num_hidden = 256
+        main_args.max_epoch = 5
+        main_args.num_layers = 4
+    elif dataset_name == 'wget':
+        main_args.num_hidden = 256
+        main_args.max_epoch = 2
+        main_args.num_layers = 4
+    else:
+        main_args["num_hidden"] = 64
+        main_args["max_epoch"] = 50
+        main_args["num_layers"] = 3
+    set_random_seed(0)
+
+    if dataset_name == 'streamspot' or dataset_name == 'wget':
+        if dataset_name == 'streamspot':
+            batch_size = 12
+        else:
+            batch_size = 1
+        dataset = load_batch_level_dataset(dataset_name)
+        n_node_feat = dataset['n_feat']
+        n_edge_feat = dataset['e_feat']
+        graphs = dataset['dataset']
+        train_index = dataset['train_index']
+        main_args.n_dim = n_node_feat
+        main_args.e_dim = n_edge_feat
+        model = build_model(main_args)
+        model = model.to(device)
+        optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay)
+        model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)),
+                                  optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim)
+        torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name))
+    else:
+        metadata = load_metadata(dataset_name)
+        main_args["n_dim"] = metadata['node_feature_dim']
+        main_args["e_dim"] = metadata['edge_feature_dim']
+        model = build_model(main_args)
+        model = model.to(device)
+        model.train()
+        optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"])
+        epoch_iter = tqdm(range(main_args["max_epoch"]))
+        n_train = metadata['n_train']
+        for epoch in epoch_iter:
+            epoch_loss = 0.0
+            for i in range(n_train):
+                g = load_entity_level_dataset(dataset_name, 'train', i).to(device)
+                model.train()
+                loss  = model(g)
+                loss /= n_train
+                optimizer.zero_grad()
+                epoch_loss += loss.item()
+                loss.backward()
+                optimizer.step()
+                del g
+            epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
+            torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name))
+
+    return
+
+
+if __name__ == '__main__':
+    args = build_args()
+    main(args)
diff --git a/trainer/__pycache__/magic_aggregator.cpython-311.pyc b/trainer/__pycache__/magic_aggregator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9f8b782ee6ef9465d777c9e882fb180828f07ea
Binary files /dev/null and b/trainer/__pycache__/magic_aggregator.cpython-311.pyc differ
diff --git a/trainer/__pycache__/magic_trainer.cpython-311.pyc b/trainer/__pycache__/magic_trainer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9ecabf7f8cc81c7ccd9f920ba01e5e3c9436ee9
Binary files /dev/null and b/trainer/__pycache__/magic_trainer.cpython-311.pyc differ
diff --git a/trainer/__pycache__/single_trainer.cpython-311.pyc b/trainer/__pycache__/single_trainer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bfbd974b57fc9e9abe1ee459c27d16e5465c94d
Binary files /dev/null and b/trainer/__pycache__/single_trainer.cpython-311.pyc differ
diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py
new file mode 100755
index 0000000000000000000000000000000000000000..a9d7f70c78d64347b655494de696050e82d0fa6f
--- /dev/null
+++ b/trainer/magic_aggregator.py
@@ -0,0 +1,127 @@
+import logging
+
+import numpy as np
+import torch
+import wandb
+from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
+from utils.config import build_args
+from fedml.core import ServerAggregator
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.poolers import Pooling
+# Trainer for MoleculeNet. The evaluation metric is ROC-AUC
+from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset
+from utils.loaddata import  load_batch_level_dataset
+
+class MagicWgetAggregator(ServerAggregator):
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
+        
+    def get_model_params(self):
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def test(self, test_data, device, args):
+        pass
+
+    def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool:
+        logging.info("----------test_on_the_server--------")
+
+        model_list, score_list = [], []
+        for client_idx in test_data_local_dict.keys():
+            test_data = test_data_local_dict[client_idx]
+            score, model = self._test(test_data, device, args)
+            for idx in range(len(model_list)):
+                self._compare_models(model, model_list[idx])
+            model_list.append(model)
+            score_list.append(score)
+            logging.info("Client {}, Test ROC-AUC score = {}".format(client_idx, score))
+            if args.enable_wandb:
+                wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
+        avg_score = np.mean(np.array(score_list))
+        logging.info("Test ROC-AUC Score = {}".format(avg_score))
+        if args.enable_wandb:
+            wandb.log({"Test/ROC-AUC": avg_score})
+        return True
+
+    def _compare_models(self, model_1, model_2):
+        models_differ = 0
+        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
+            if torch.equal(key_item_1[1], key_item_2[1]):
+                pass
+            else:
+                models_differ += 1
+                if key_item_1[0] == key_item_2[0]:
+                    logging.info("Mismatch found at", key_item_1[0])
+                else:
+                    raise Exception
+        if models_differ == 0:
+            logging.info("Models match perfectly! :)")
+
+    def _test(self, test_data, device, args):
+        args = build_args()           
+        if (self.name == 'wget' or self.name == 'streamspot'):
+            logging.info("----------test--------")
+                 
+            model = self.model
+            model.eval()
+            model.to(device)
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"],  args["e_dim"])
+        else:
+            torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name))
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                   g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                   x_train.append(model.embed(g).cpu().detach().numpy())
+                   del g
+                x_train = np.concatenate(x_train, axis=0)
+                skip_benign = 0
+            x_test = []
+            for i in range(n_test):
+                g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                # Exclude training samples from the test set
+                if i != n_test - 1:
+                    skip_benign += g.number_of_nodes()
+                x_test.append(model.embed(g).cpu().detach().numpy())
+                del g
+            x_test = np.concatenate(x_test, axis=0)
+
+            n = x_test.shape[0]
+            y_test = np.zeros(n)
+            y_test[malicious] = 1.0
+            malicious_dict = {}
+            for i, m in enumerate(malicious):
+                malicious_dict[m] = i
+
+            # Exclude training samples from the test set
+            test_idx = []
+            for i in range(x_test.shape[0]):
+                if i >= skip_benign or y_test[i] == 1.0:
+                    test_idx.append(i)
+            result_x_test = x_test[test_idx]
+            result_y_test = y_test[test_idx]
+            del x_test, y_test
+            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                       result_y_test)
+        torch.save(model.state_dict(), "./result/FedAvg-{}.pt".format(self.name))
+        return test_auc, model
+                                                    
diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..99c9a5b49e31a631f30053db0c3cf0623569b2eb
--- /dev/null
+++ b/trainer/magic_trainer.py
@@ -0,0 +1,186 @@
+import logging
+import os
+import random
+import torch
+import warnings
+from tqdm import tqdm
+import numpy as np
+import torch
+import wandb
+from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
+
+from fedml.core import ClientTrainer
+from model.autoencoder import build_model
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+from model.train import batch_level_train
+from utils.utils import set_random_seed, create_optimizer
+from utils.poolers import Pooling
+from utils.config import build_args
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from model.autoencoder import build_model
+from data.data_loader import transform_data
+from utils.loaddata import load_batch_level_dataset
+
+# Trainer for MoleculeNet. The evaluation metric is ROC-AUC
+def extract_dataloaders(entries, batch_size):
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+
+class MagicTrainer(ClientTrainer):
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
+        self.max = 0
+    	
+    def get_model_params(self):
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def train(self, train_data, device, args):
+        test_data = None
+        args = build_args()
+        if (self.name == "wget"):
+            args["num_hidden"] = 256
+            args["max_epoch"] = 2
+            args["num_layers"] = 4
+            batch_size = 1
+        elif (self.name == "streamspot"):
+            args["num_hidden"] = 256
+            args["max_epoch"] = 5
+            args["num_layers"] = 4
+            batch_size = 12
+        else:
+            args["num_hidden"] = 64
+            args["max_epoch"] = 50
+            args["num_layers"] = 3
+        
+        max_test_score = 0
+        best_model_params = {}       
+        if (self.name == 'wget' or self.name == 'streamspot'):
+ 
+            dataset = load_batch_level_dataset(self.name)
+            data = transform_data(train_data)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            graphs = data['dataset']
+            train_index = data['train_index']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+        #self.model = build_model(args)
+            self.model = self.model.to(device)
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            self.model = batch_level_train(self.model, graphs, (extract_dataloaders(train_index, batch_size)),
+                                  optimizer, args["max_epoch"], device, n_node_feat,   n_edge_feat)
+            test_score, _ = self.test(test_data, device, args)
+        else:
+            
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            self.model = self.model.to(device)
+            self.model.train()
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            epoch_iter = tqdm(range(args["max_epoch"]))
+            n_train = len(train_data[0])
+            input("start?")
+            for epoch in epoch_iter:
+                epoch_loss = 0.0
+                for i in range(n_train):
+                    g = train_data[0][i]
+                    self.model.train()
+                    loss  = self.model(g)
+                    loss /= n_train
+                    optimizer.zero_grad()
+                    epoch_loss += loss.item()
+                    loss.backward(retain_graph=True)
+                    optimizer.step()
+                    del g
+            epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
+        if (self.name == 'wget' or self.name == 'streamspot'):    
+            test_score, _ = self.test(test_data, device, args)
+            if test_score > self.max:
+                self.max = test_score
+                best_model_params = {
+                k: v.cpu() for k, v in self.model.state_dict().items()
+                }
+        else:
+            self.max = 0
+            best_model_params = {
+            k: v.cpu() for k, v in self.model.state_dict().items()
+            }
+            
+     
+
+        return self.max, best_model_params
+
+    def test(self, test_data, device, args):
+        if (self.name == 'wget' or self.name == 'streamspot'):
+            logging.info("----------test--------")
+            args = build_args()        
+            model = self.model
+            model.eval()
+            model.to(device)
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"],  args["e_dim"])
+        else:
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                   g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                   x_train.append(model.embed(g).cpu().detach().numpy())
+                   del g
+                x_train = np.concatenate(x_train, axis=0)
+                skip_benign = 0
+            x_test = []
+            for i in range(n_test):
+                g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                # Exclude training samples from the test set
+                if i != n_test - 1:
+                    skip_benign += g.number_of_nodes()
+                x_test.append(model.embed(g).cpu().detach().numpy())
+                del g
+            x_test = np.concatenate(x_test, axis=0)
+
+            n = x_test.shape[0]
+            y_test = np.zeros(n)
+            y_test[malicious] = 1.0
+            malicious_dict = {}
+            for i, m in enumerate(malicious):
+                malicious_dict[m] = i
+
+            # Exclude training samples from the test set
+            test_idx = []
+            for i in range(x_test.shape[0]):
+                if i >= skip_benign or y_test[i] == 1.0:
+                    test_idx.append(i)
+            result_x_test = x_test[test_idx]
+            result_y_test = y_test[test_idx]
+            del x_test, y_test
+            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                       result_y_test)
+
+        return test_auc, model
+
diff --git a/trainer/single_trainer.py b/trainer/single_trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..19f2686a46ed4aafdeeffc94219bf670119684cd
--- /dev/null
+++ b/trainer/single_trainer.py
@@ -0,0 +1,44 @@
+import os
+import random
+import torch
+import warnings
+from tqdm import tqdm
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
+from model.autoencoder import build_model
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+from model.train import batch_level_train
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.utils import set_random_seed, create_optimizer
+from utils.config import build_args
+from utils.poolers import Pooling
+warnings.filterwarnings('ignore')
+
+
+def extract_dataloaders(entries, batch_size):
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+def train_single(main_args, model, dataset):
+    device = "cpu"
+    set_random_seed(0)
+    batch_size = 1
+    n_node_feat = dataset['n_feat']
+    n_edge_feat = dataset['e_feat']
+    graphs = dataset['dataset']
+    train_index = dataset['train_index']
+    model = model.to(device)
+    optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"])
+    model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)),
+                                  optimizer, main_args["max_epoch"], device, main_args["n_dim"], main_args["e_dim"])
+    #torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format("wget"))
+    pooler = Pooling(main_args["pooling"])
+    test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], "wget" ,main_args["n_dim"],  main_args["e_dim"])
+    return test_auc, model
+
+
+
diff --git a/trainer/utils/__pycache__/config.cpython-311.pyc b/trainer/utils/__pycache__/config.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..233d940b3ec7d7070e2ee889729ac7592675d3bf
Binary files /dev/null and b/trainer/utils/__pycache__/config.cpython-311.pyc differ
diff --git a/trainer/utils/config.py b/trainer/utils/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..bb9c1ef28d385b80c7673027b4f69b31c36b1175
--- /dev/null
+++ b/trainer/utils/config.py
@@ -0,0 +1,20 @@
+import argparse
+
+
+def build_args():
+    parser = argparse.ArgumentParser(description="MAGIC")
+    parser.add_argument("--dataset", type=str, default="wget")
+    parser.add_argument("--device", type=int, default=-1)
+    parser.add_argument("--lr", type=float, default=0.001,
+                        help="learning rate")
+    parser.add_argument("--weight_decay", type=float, default=5e-4,
+                        help="weight decay")
+    parser.add_argument("--negative_slope", type=float, default=0.2,
+                        help="the negative slope of leaky relu for GAT")
+    parser.add_argument("--mask_rate", type=float, default=0.5)
+    parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss")
+    parser.add_argument("--optimizer", type=str, default="adam")
+    parser.add_argument("--loss_fn", type=str, default='sce')
+    parser.add_argument("--pooling", type=str, default="mean")
+    args = parser.parse_args()
+    return args
diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py
new file mode 100755
index 0000000000000000000000000000000000000000..41e7dfc03ee39adcc1f5729d59aa21124d981fff
--- /dev/null
+++ b/trainer/utils/loaddata.py
@@ -0,0 +1,197 @@
+import pickle as pkl
+import time
+import torch.nn.functional as F
+import dgl
+import networkx as nx
+import json
+from tqdm import tqdm
+import os
+
+
+class StreamspotDataset(dgl.data.DGLDataset):
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(StreamspotDataset, self).__init__(name=name)
+        if name == 'streamspot':
+            path = './data/streamspot'
+            num_graphs = 600
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                if 300 <= idx <= 399:
+                    self.labels.append(1)
+                else:
+                    self.labels.append(0)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+
+class WgetDataset(dgl.data.DGLDataset):
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(WgetDataset, self).__init__(name=name)
+        if name == 'wget':
+            path = './data/wget/final'
+            num_graphs = 150
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                if 0 <= idx <= 24:
+                    self.labels.append(1)
+                else:
+                    self.labels.append(0)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+
+def load_rawdata(name):
+    if name == 'streamspot':
+        path = './data/streamspot'
+        if os.path.exists(path + '/graphs.pkl'):
+            print('Loading processed {} dataset...'.format(name))
+            raw_data = pkl.load(open(path + '/graphs.pkl', 'rb'))
+        else:
+            raw_data = StreamspotDataset(name)
+            pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb'))
+    elif name == 'wget':
+        path = './data/wget'
+        if os.path.exists(path + '/graphs.pkl'):
+            print('Loading processed {} dataset...'.format(name))
+            raw_data = pkl.load(open(path + '/graphs.pkl', 'rb'))
+        else:
+            raw_data = WgetDataset(name)
+            pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb'))
+    else:
+        raise NotImplementedError
+    return raw_data
+
+
+def load_batch_level_dataset(dataset_name):
+    dataset = load_rawdata(dataset_name)
+    graph, _ = dataset[0]
+    node_feature_dim = 0
+    for g, _ in dataset:
+        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
+    edge_feature_dim = 0
+    for g, _ in dataset:
+        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
+    node_feature_dim += 1
+    edge_feature_dim += 1
+    full_dataset = [i for i in range(len(dataset))]
+    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
+    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
+
+    return {'dataset': dataset,
+            'train_index': train_dataset,
+            'full_index': full_dataset,
+            'n_feat': node_feature_dim,
+            'e_feat': edge_feature_dim}
+
+
+def transform_graph(g, node_feature_dim, edge_feature_dim):
+    new_g = g.clone()
+    new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float()
+    new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float()
+    return new_g
+
+
+def preload_entity_level_dataset(path):
+    path = './data/' + path
+    if os.path.exists(path + '/metadata.json'):
+        pass
+    else:
+        print('transforming')
+        train_gs = [dgl.from_networkx(
+            nx.node_link_graph(g),
+            node_attrs=['type'],
+            edge_attrs=['type']
+        ) for g in pkl.load(open(path + '/train.pkl', 'rb'))]
+        print('transforming')
+        test_gs = [dgl.from_networkx(
+            nx.node_link_graph(g),
+            node_attrs=['type'],
+            edge_attrs=['type']
+        ) for g in pkl.load(open(path + '/test.pkl', 'rb'))]
+        malicious = pkl.load(open(path + '/malicious.pkl', 'rb'))
+
+        node_feature_dim = 0
+        for g in train_gs:
+            node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim)
+        for g in test_gs:
+            node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim)
+        node_feature_dim += 1
+        edge_feature_dim = 0
+        for g in train_gs:
+            edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim)
+        for g in test_gs:
+            edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim)
+        edge_feature_dim += 1
+        result_test_gs = []
+        for g in test_gs:
+            g = transform_graph(g, node_feature_dim, edge_feature_dim)
+            result_test_gs.append(g)
+        result_train_gs = []
+        for g in train_gs:
+            g = transform_graph(g, node_feature_dim, edge_feature_dim)
+            result_train_gs.append(g)
+        metadata = {
+            'node_feature_dim': node_feature_dim,
+            'edge_feature_dim': edge_feature_dim,
+            'malicious': malicious,
+            'n_train': len(result_train_gs),
+            'n_test': len(result_test_gs)
+        }
+        with open(path + '/metadata.json', 'w', encoding='utf-8') as f:
+            json.dump(metadata, f)
+        for i, g in enumerate(result_train_gs):
+            with open(path + '/train{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
+        for i, g in enumerate(result_test_gs):
+            with open(path + '/test{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
+
+
+def load_metadata(path):
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f:
+        metadata = json.load(f)
+    return metadata
+
+
+def load_entity_level_dataset(path, t, n):
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f:
+        data = pkl.load(f)
+    return data
diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py
new file mode 100755
index 0000000000000000000000000000000000000000..4b9dc4aee1eac42f578952596a5e400adbd1e391
--- /dev/null
+++ b/trainer/utils/poolers.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+
+class Pooling(nn.Module):
+    def __init__(self, pooler):
+        super(Pooling, self).__init__()
+        self.pooler = pooler
+
+    def forward(self, graph, feat, t=None):
+        feat = feat
+        # Implement node type-specific pooling
+        with graph.local_scope():
+            if t is None:
+                if self.pooler == 'mean':
+                    return feat.mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat.sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat.max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
+            elif isinstance(t, int):
+                mask = (graph.ndata['type'] == t)
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
+            else:
+                mask = (graph.ndata['type'] == t[0])
+                for i in range(1, len(t)):
+                    mask |= (graph.ndata['type'] == t[i])
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
diff --git a/trainer/utils/streamspot_parser.py b/trainer/utils/streamspot_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..04438eb0f6336ae2bad2015ad6caf4919f48f9a3
--- /dev/null
+++ b/trainer/utils/streamspot_parser.py
@@ -0,0 +1,53 @@
+import networkx as nx
+from tqdm import tqdm
+import json
+raw_path = '../data/streamspot/'
+
+NUM_GRAPHS = 600
+node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
+                  'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G']
+node_type_set = set(node_type_dict)
+edge_type_set = set(edge_type_dict)
+
+count_graph = 0
+with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f:
+    lines = f.readlines()
+    g = nx.DiGraph()
+    node_map = {}
+    count_node = 0
+    for line in tqdm(lines):
+        src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t')
+        graph_id = int(graph_id)
+        if src_type not in node_type_set or dst_type not in node_type_set:
+            continue
+        if etype not in edge_type_set:
+            continue
+        if graph_id != count_graph:
+            count_graph += 1
+            for n in g.nodes():
+                g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type'])
+            for e in g.edges():
+                g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type'])
+            f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8')
+            json.dump(nx.node_link_data(g), f1)
+            assert graph_id == count_graph
+            g = nx.DiGraph()
+            count_node = 0
+        if src not in node_map:
+            node_map[src] = count_node
+            g.add_node(count_node, type=src_type)
+            count_node += 1
+        if dst not in node_map:
+            node_map[dst] = count_node
+            g.add_node(count_node, type=dst_type)
+            count_node += 1
+        if not g.has_edge(node_map[src], node_map[dst]):
+            g.add_edge(node_map[src], node_map[dst], type=etype)
+    count_graph += 1
+    for n in g.nodes():
+        g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type'])
+    for e in g.edges():
+        g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type'])
+    f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8')
+    json.dump(nx.node_link_data(g), f1)
diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7
--- /dev/null
+++ b/trainer/utils/trace_parser.py
@@ -0,0 +1,288 @@
+import argparse
+import json
+import os
+import random
+import re
+
+from tqdm import tqdm
+import networkx as nx
+import pickle as pkl
+
+
+node_type_dict = {}
+edge_type_dict = {}
+node_type_cnt = 0
+edge_type_cnt = 0
+
+metadata = {
+    'trace':{
+        'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'],
+        'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4']
+    },
+    'theia':{
+            'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'],
+            'test': ['ta1-theia-e3-official-6r.json.8']
+    },
+    'cadets':{
+            'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'],
+            'test': ['ta1-cadets-e3-official-2.json']
+    }
+}
+
+
+pattern_uuid = re.compile(r'uuid\":\"(.*?)\"')
+pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_type = re.compile(r'type\":\"(.*?)\"')
+pattern_time = re.compile(r'timestampNanos\":(.*?),')
+pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"')
+pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"')
+pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"')
+
+adversarial = 'None'
+
+
+def read_single_graph(dataset, malicious, path, test=False):
+    global node_type_cnt, edge_type_cnt
+    g = nx.DiGraph()
+    print('converting {} ...'.format(path))
+    path = '../data/{}/'.format(dataset) + path + '.txt'
+    f = open(path, 'r')
+    lines = []
+    for l in f.readlines():
+        split_line = l.split('\t')
+        src, src_type, dst, dst_type, edge_type, ts = split_line
+        ts = int(ts)
+        if not test:
+            if src in malicious or dst in malicious:
+                if src in malicious and src_type != 'MemoryObject':
+                    continue
+                if dst in malicious and dst_type != 'MemoryObject':
+                    continue
+
+        if src_type not in node_type_dict:
+            node_type_dict[src_type] = node_type_cnt
+            node_type_cnt += 1
+        if dst_type not in node_type_dict:
+            node_type_dict[dst_type] = node_type_cnt
+            node_type_cnt += 1
+        if edge_type not in edge_type_dict:
+            edge_type_dict[edge_type] = edge_type_cnt
+            edge_type_cnt += 1
+        if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type:
+            lines.append([dst, src, dst_type, src_type, edge_type, ts])
+        else:
+            lines.append([src, dst, src_type, dst_type, edge_type, ts])
+    lines.sort(key=lambda l: l[5])
+
+    node_map = {}
+    node_type_map = {}
+    node_cnt = 0
+    node_list = []
+    for l in lines:
+        src, dst, src_type, dst_type, edge_type = l[:5]
+        src_type_id = node_type_dict[src_type]
+        dst_type_id = node_type_dict[dst_type]
+        edge_type_id = edge_type_dict[edge_type]
+        if src not in node_map:
+            if test:
+                if adversarial == 'MFE' or adversarial == 'MCE':
+                    if src in malicious:
+                        src_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test:
+                if adversarial == 'BFP':
+                    if src not in malicious:
+                        i = random.randint(1, 20)
+                        if i == 1:
+                            src_type_id = node_type_dict['NetFlowObject']
+            node_map[src] = node_cnt
+            g.add_node(node_cnt, type=src_type_id)
+            node_list.append(src)
+            node_type_map[src] = src_type
+            node_cnt += 1
+        if dst not in node_map:
+            if test:
+                if adversarial == 'MFE' or adversarial == 'MCE':
+                    if dst in malicious:
+                        dst_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test:
+                if adversarial == 'BFP':
+                    if dst not in malicious:
+                        i = random.randint(1, 20)
+                        if i == 1:
+                            dst_type_id = node_type_dict['NetFlowObject']
+            node_map[dst] = node_cnt
+            g.add_node(node_cnt, type=dst_type_id)
+            node_type_map[dst] = dst_type
+            node_list.append(dst)
+            node_cnt += 1
+        if not g.has_edge(node_map[src], node_map[dst]):
+            g.add_edge(node_map[src], node_map[dst], type=edge_type_id)
+    if (adversarial == 'MSE' or adversarial == 'MCE') and test:
+        for i, node in enumerate(node_list):
+            if node in malicious:
+                while True:
+                    another_node = random.choice(node_list)
+                    if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT'])
+                        print(node, another_node)
+                        break
+                    if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject':
+                        break
+    return node_map, g
+
+
+def preprocess_dataset(dataset):
+    id_nodetype_map = {}
+    id_nodename_map = {}
+    for file in os.listdir('../data/{}/'.format(dataset)):
+        if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file:
+            print('reading {} ...'.format(file))
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue
+                if len(pattern_uuid.findall(line)) == 0: print(line)
+                uuid = pattern_uuid.findall(line)[0]
+                subject_type = pattern_type.findall(line)
+
+                if len(subject_type) < 1:
+                    if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line:
+                        subject_type = 'MemoryObject'
+                    if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line:
+                        subject_type = 'NetFlowObject'
+                    if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line:
+                        subject_type = 'UnnamedPipeObject'
+                else:
+                    subject_type = subject_type[0]
+
+                if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']:
+                    continue
+                id_nodetype_map[uuid] = subject_type
+                if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_file_name.findall(line)[0]
+                elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_process_name.findall(line)[0]
+                elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0]
+    for key in metadata[dataset]:
+        for file in metadata[dataset][key]:
+            if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'):
+                continue
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8')
+            print('processing {} ...'.format(file))
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line:
+                    edgeType = pattern_type.findall(line)[0]
+                    timestamp = pattern_time.findall(line)[0]
+                    srcId = pattern_src.findall(line)
+
+                    if len(srcId) == 0: continue
+                    srcId = srcId[0]
+                    if not srcId in id_nodetype_map:
+                        continue
+                    srcType = id_nodetype_map[srcId]
+                    dstId1 = pattern_dst1.findall(line)
+                    if len(dstId1) > 0 and dstId1[0] != 'null':
+                        dstId1 = dstId1[0]
+                        if not dstId1 in id_nodetype_map:
+                            continue
+                        dstType1 = id_nodetype_map[dstId1]
+                        this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str(
+                            dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n'
+                        fw.write(this_edge1)
+
+                    dstId2 = pattern_dst2.findall(line)
+                    if len(dstId2) > 0 and dstId2[0] != 'null':
+                        dstId2 = dstId2[0]
+                        if not dstId2 in id_nodetype_map.keys():
+                            continue
+                        dstType2 = id_nodetype_map[dstId2]
+                        this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str(
+                            dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n'
+                        fw.write(this_edge2)
+            fw.close()
+            f.close()
+    if len(id_nodename_map) != 0:
+        fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8')
+        json.dump(id_nodename_map, fw)
+    if len(id_nodetype_map) != 0:
+        fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8')
+        json.dump(id_nodetype_map, fw)
+
+
+def read_graphs(dataset):
+    malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset)
+    f = open(malicious_entities, 'r')
+    malicious_entities = set()
+    for l in f.readlines():
+        malicious_entities.add(l.lstrip().rstrip())
+
+    preprocess_dataset(dataset)
+    train_gs = []
+    for file in metadata[dataset]['train']:
+        _, train_g = read_single_graph(dataset, malicious_entities, file, False)
+        train_gs.append(train_g)
+    test_gs = []
+    test_node_map = {}
+    count_node = 0
+    for file in metadata[dataset]['test']:
+        node_map, test_g = read_single_graph(dataset, malicious_entities, file, True)
+        assert len(node_map) == test_g.number_of_nodes()
+        test_gs.append(test_g)
+        for key in node_map:
+            if key not in test_node_map:
+                test_node_map[key] = node_map[key] + count_node
+        count_node += test_g.number_of_nodes()
+
+    if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)):
+        with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodename_map = json.load(f)
+        with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodetype_map = json.load(f)
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject':
+                final_malicious_entities.append(test_node_map[e])
+                if e in id_nodename_map:
+                    malicious_names.append(id_nodename_map[e])
+                    f.write('{}\t{}\n'.format(e, id_nodename_map[e]))
+                else:
+                    malicious_names.append(e)
+                    f.write('{}\t{}\n'.format(e, e))
+    else:
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map:
+                final_malicious_entities.append(test_node_map[e])
+                malicious_names.append(e)
+                f.write('{}\t{}\n'.format(e, e))
+
+    pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb'))
+    pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb'))
+    pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb'))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='CDM Parser')
+    parser.add_argument("--dataset", type=str, default="trace")
+    args = parser.parse_args()
+    if args.dataset not in ['trace', 'theia', 'cadets']:
+        raise NotImplementedError
+    read_graphs(args.dataset)
+
diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..bcf9a481407696d73f30fe1dde279154a05702b3
--- /dev/null
+++ b/trainer/utils/utils.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import numpy as np
+import random
+import torch.optim as optim
+
+
+def create_optimizer(opt, model, lr, weight_decay):
+    opt_lower = opt.lower()
+    parameters = model.parameters()
+    opt_args = dict(lr=lr, weight_decay=weight_decay)
+    optimizer = None
+    opt_split = opt_lower.split("_")
+    opt_lower = opt_split[-1]
+    if opt_lower == "adam":
+        optimizer = optim.Adam(parameters, **opt_args)
+    elif opt_lower == "adamw":
+        optimizer = optim.AdamW(parameters, **opt_args)
+    elif opt_lower == "adadelta":
+        optimizer = optim.Adadelta(parameters, **opt_args)
+    elif opt_lower == "radam":
+        optimizer = optim.RAdam(parameters, **opt_args)
+    elif opt_lower == "sgd":
+        opt_args["momentum"] = 0.9
+        return optim.SGD(parameters, **opt_args)
+    else:
+        assert False and "Invalid optimizer"
+    return optimizer
+
+
+def random_shuffle(x, y):
+    idx = list(range(len(x)))
+    random.shuffle(idx)
+    return x[idx], y[idx]
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.determinstic = True
+
+
+def create_activation(name):
+    if name == "relu":
+        return nn.ReLU()
+    elif name == "gelu":
+        return nn.GELU()
+    elif name == "prelu":
+        return nn.PReLU()
+    elif name is None:
+        return nn.Identity()
+    elif name == "elu":
+        return nn.ELU()
+    else:
+        raise NotImplementedError(f"{name} is not implemented.")
+
+
+def create_norm(name):
+    if name == "layernorm":
+        return nn.LayerNorm
+    elif name == "batchnorm":
+        return nn.BatchNorm1d
+    elif name == "graphnorm":
+        return partial(NormLayer, norm_type="groupnorm")
+    else:
+        return None
+
+
+class NormLayer(nn.Module):
+    def __init__(self, hidden_dim, norm_type):
+        super().__init__()
+        if norm_type == "batchnorm":
+            self.norm = nn.BatchNorm1d(hidden_dim)
+        elif norm_type == "layernorm":
+            self.norm = nn.LayerNorm(hidden_dim)
+        elif norm_type == "graphnorm":
+            self.norm = norm_type
+            self.weight = nn.Parameter(torch.ones(hidden_dim))
+            self.bias = nn.Parameter(torch.zeros(hidden_dim))
+
+            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
+        else:
+            raise NotImplementedError
+
+    def forward(self, graph, x):
+        tensor = x
+        if self.norm is not None and type(self.norm) != str:
+            return self.norm(tensor)
+        elif self.norm is None:
+            return tensor
+
+        batch_list = graph.batch_num_nodes
+        batch_size = len(batch_list)
+        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
+        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
+        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
+        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        mean = mean.scatter_add_(0, batch_index, tensor)
+        mean = (mean.T / batch_list).T
+        mean = mean.repeat_interleave(batch_list, dim=0)
+
+        sub = tensor - mean * self.mean_scale
+
+        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        std = std.scatter_add_(0, batch_index, sub.pow(2))
+        std = ((std.T / batch_list).T + 1e-6).sqrt()
+        std = std.repeat_interleave(batch_list, dim=0)
+        return self.weight * sub / std + self.bias
diff --git a/trainer/utils/wget_parser.py b/trainer/utils/wget_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055
--- /dev/null
+++ b/trainer/utils/wget_parser.py
@@ -0,0 +1,820 @@
+import argparse
+import json
+import os
+
+import xxhash
+import tqdm
+import logging
+import networkx as nx
+from tqdm import tqdm
+import time
+import datetime
+
+
+valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link']
+CONSOLE_ARGUMENTS = None
+
+
+def hashgen(l):
+    """Generate a single hash value from a list. @l is a list of
+    string values, which can be properties of a node/edge. This
+    function returns a single hashed integer value."""
+    hasher = xxhash.xxh64()
+    for e in l:
+        hasher.update(e)
+    return hasher.intdigest()
+
+
+def parse_nodes(json_string, node_map):
+    """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity").
+    Parsed nodes populate @node_map, which is a dictionary that maps the node's UID,
+    which is assigned by CamFlow to uniquely identify a node object, to a hashed
+    value (in str) which represents the 'type' of the node. """
+    json_object = None
+    try:
+        # use "ignore" if non-decodeable exists in the @json_string
+        json_object = json.loads(json_string)
+    except Exception as e:
+        print("Exception ({}) occurred when parsing a node in JSON:".format(e))
+        print(json_string)
+        exit(1)
+    if "activity" in json_object:
+        activity = json_object["activity"]
+        for uid in activity:
+            if not uid in node_map:  # only parse unseen nodes
+                if "prov:type" not in activity[uid]:
+                    # a node must have a type.
+                    # record this issue if logging is turned on
+                    if CONSOLE_ARGUMENTS.verbose:
+                        logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid))
+                else:
+                    node_map[uid] = activity[uid]["prov:type"]
+
+    if "entity" in json_object:
+        entity = json_object["entity"]
+        for uid in entity:
+            if not uid in node_map:
+                if "prov:type" not in entity[uid]:
+                    if CONSOLE_ARGUMENTS.verbose:
+                        logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid))
+                else:
+                    node_map[uid] = entity[uid]["prov:type"]
+
+
+def parse_all_nodes(filename, node_map):
+    """Parse all nodes in CamFlow data. @filename is the file path of
+    the CamFlow data to parse. @node_map contains the mappings of all
+    CamFlow nodes to their hashed attributes. """
+    description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename)
+    pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+    with open(filename, 'r') as f:
+        # each line in CamFlow data could contain multiple
+        # provenance nodes, we call @parse_nodes routine.
+        for line in f:
+            pb.update()  # for progress tracking
+            parse_nodes(line, node_map)
+    f.close()
+    pb.close()
+
+
+def parse_all_edges(inputfile, outputfile, node_map, noencode):
+    """Parse all edges (including their timestamp) from CamFlow data file @inputfile
+    to an @outputfile. Before this function is called, parse_all_nodes should be called
+    to populate the @node_map for all nodes in the CamFlow file. If @noencode is set,
+    we do not hash the nodes' original UUIDs generated by CamFlow to integers. This
+    function returns the total number of valid edge parsed from CamFlow dataset.
+
+    The output edgelist has the following format for each line, if -s is not set:
+        <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>
+    If -s is set, each line would look like:
+        <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>"""
+    total_edges = 0
+    smallest_timestamp = None
+    # scan through the entire file to find the smallest timestamp from all the edges.
+    # this step is only needed if we need to add some statistical information.
+    if CONSOLE_ARGUMENTS.stats:
+        description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile)
+        pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+        with open(inputfile, 'r') as f:
+            for line in f:
+                pb.update()
+                json_object = json.loads(line)
+
+                if "used" in json_object:
+                    used = json_object["used"]
+                    for uid in used:
+                        if "prov:type" not in used[uid]:
+                            continue
+                        if "cf:date" not in used[uid]:
+                            continue
+                        if "prov:entity" not in used[uid]:
+                            continue
+                        if "prov:activity" not in used[uid]:
+                            continue
+                        srcUUID = used[uid]["prov:entity"]
+                        dstUUID = used[uid]["prov:activity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = used[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasGeneratedBy" in json_object:
+                    wasGeneratedBy = json_object["wasGeneratedBy"]
+                    for uid in wasGeneratedBy:
+                        if "prov:type" not in wasGeneratedBy[uid]:
+                            continue
+                        if "cf:date" not in wasGeneratedBy[uid]:
+                            continue
+                        if "prov:entity" not in wasGeneratedBy[uid]:
+                            continue
+                        if "prov:activity" not in wasGeneratedBy[uid]:
+                            continue
+                        srcUUID = wasGeneratedBy[uid]["prov:activity"]
+                        dstUUID = wasGeneratedBy[uid]["prov:entity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasGeneratedBy[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasInformedBy" in json_object:
+                    wasInformedBy = json_object["wasInformedBy"]
+                    for uid in wasInformedBy:
+                        if "prov:type" not in wasInformedBy[uid]:
+                            continue
+                        if "cf:date" not in wasInformedBy[uid]:
+                            continue
+                        if "prov:informant" not in wasInformedBy[uid]:
+                            continue
+                        if "prov:informed" not in wasInformedBy[uid]:
+                            continue
+                        srcUUID = wasInformedBy[uid]["prov:informant"]
+                        dstUUID = wasInformedBy[uid]["prov:informed"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasInformedBy[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasDerivedFrom" in json_object:
+                    wasDerivedFrom = json_object["wasDerivedFrom"]
+                    for uid in wasDerivedFrom:
+                        if "prov:type" not in wasDerivedFrom[uid]:
+                            continue
+                        if "cf:date" not in wasDerivedFrom[uid]:
+                            continue
+                        if "prov:usedEntity" not in wasDerivedFrom[uid]:
+                            continue
+                        if "prov:generatedEntity" not in wasDerivedFrom[uid]:
+                            continue
+                        srcUUID = wasDerivedFrom[uid]["prov:usedEntity"]
+                        dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasDerivedFrom[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasAssociatedWith" in json_object:
+                    wasAssociatedWith = json_object["wasAssociatedWith"]
+                    for uid in wasAssociatedWith:
+                        if "prov:type" not in wasAssociatedWith[uid]:
+                            continue
+                        if "cf:date" not in wasAssociatedWith[uid]:
+                            continue
+                        if "prov:agent" not in wasAssociatedWith[uid]:
+                            continue
+                        if "prov:activity" not in wasAssociatedWith[uid]:
+                            continue
+                        srcUUID = wasAssociatedWith[uid]["prov:agent"]
+                        dstUUID = wasAssociatedWith[uid]["prov:activity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasAssociatedWith[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+        f.close()
+        pb.close()
+
+    # we will go through the CamFlow data (again) and output edgelist to a file
+    output = open(outputfile, "w+")
+    description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile)
+    pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+    with open(inputfile, 'r') as f:
+        for line in f:
+            pb.update()
+            json_object = json.loads(line)
+
+            if "used" in json_object:
+                used = json_object["used"]
+                for uid in used:
+                    if "prov:type" not in used[uid]:
+                        # an edge must have a type; if not,
+                        # we will have to skip the edge. Log
+                        # this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "used"
+                    # cf:id is used as logical timestamp to order edges
+                    if "cf:id" not in used[uid]:
+                        # an edge must have a logical timestamp;
+                        # if not we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = used[uid]["cf:id"]
+                    if "prov:entity" not in used[uid]:
+                        # an edge's source node must exist;
+                        # if not, we will have to skip the
+                        # edge. Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in used[uid]:
+                        # an edge's destination node must exist;
+                        # if not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"],
+                                                                                            uid))
+                        continue
+                    srcUUID = used[uid]["prov:entity"]
+                    dstUUID = used[uid]["prov:activity"]
+                    # both source and destination node must
+                    # exist in @node_map; if not, we will
+                    # have to skip the edge. Log this issue
+                    # if verbose is set.
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in used[uid]:
+                        # an edge must have a timestamp; if
+                        # not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        # we only record @adjusted_ts if we need
+                        # to record stats of CamFlow dataset.
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = used[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in used[uid]:
+                        # an edge must have a jiffies timestamp; if
+                        # not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        # we only record @jiffies if
+                        # the option is set
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = used[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp))
+
+            if "wasGeneratedBy" in json_object:
+                wasGeneratedBy = json_object["wasGeneratedBy"]
+                for uid in wasGeneratedBy:
+                    if "prov:type" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasGeneratedBy"
+                    if "cf:id" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasGeneratedBy[uid]["cf:id"]
+                    if "prov:entity" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasGeneratedBy[uid]["prov:activity"]
+                    dstUUID = wasGeneratedBy[uid]["prov:entity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasGeneratedBy[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasGeneratedBy[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasInformedBy" in json_object:
+                wasInformedBy = json_object["wasInformedBy"]
+                for uid in wasInformedBy:
+                    if "prov:type" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasInformedBy"
+                    if "cf:id" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasInformedBy[uid]["cf:id"]
+                    if "prov:informant" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    if "prov:informed" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasInformedBy[uid]["prov:informant"]
+                    dstUUID = wasInformedBy[uid]["prov:informed"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasInformedBy[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasInformedBy[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasDerivedFrom" in json_object:
+                wasDerivedFrom = json_object["wasDerivedFrom"]
+                for uid in wasDerivedFrom:
+                    if "prov:type" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasDerivedFrom"
+                    if "cf:id" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid))
+                            continue
+                    else:
+                        timestamp = wasDerivedFrom[uid]["cf:id"]
+                    if "prov:usedEntity" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    if "prov:generatedEntity" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasDerivedFrom[uid]["prov:usedEntity"]
+                    dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasDerivedFrom[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasDerivedFrom[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasAssociatedWith" in json_object:
+                wasAssociatedWith = json_object["wasAssociatedWith"]
+                for uid in wasAssociatedWith:
+                    if "prov:type" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasAssociatedWith"
+                    if "cf:id" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasAssociatedWith[uid]["cf:id"]
+                    if "prov:agent" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasAssociatedWith[uid]["prov:agent"]
+                    dstUUID = wasAssociatedWith[uid]["prov:activity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasAssociatedWith[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasAssociatedWith[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+    f.close()
+    output.close()
+    pb.close()
+    return total_edges
+
+def read_single_graph(file_name, threshold):
+    graph = []
+    edge_cnt = 0
+    with open(file_name, 'r') as f:
+        for line in f:
+            try:
+                edge = line.strip().split("\t")
+                new_edge = [edge[0], edge[1]]
+                attributes = edge[2].strip().split(":")
+                source_node_type = attributes[0]
+                destination_node_type = attributes[1]
+                edge_type = attributes[2]
+                edge_order = attributes[3]
+
+                new_edge.append(source_node_type)
+                new_edge.append(destination_node_type)
+                new_edge.append(edge_type)
+                new_edge.append(edge_order)
+                graph.append(new_edge)
+                edge_cnt += 1
+            except:
+                print("{}".format(line))
+    f.close()
+    graph.sort(key=lambda e: e[5])
+    if len(graph) <= threshold:
+        return graph
+    else:
+        return graph[:threshold]
+
+
+def process_graph(name, threshold):
+    graph = read_single_graph(name, threshold)
+    result_graph = nx.DiGraph()
+    cnt = 0
+    for num, edge in enumerate(graph):
+        cnt += 1
+        src, dst, src_type, dst_type, edge_type = edge[:5]
+        if True:# src_type in valid_node_type and dst_type in valid_node_type:
+            if not result_graph.has_node(src):
+                result_graph.add_node(src, type=src_type)
+            if not result_graph.has_node(dst):
+                result_graph.add_node(dst, type=dst_type)
+            if not result_graph.has_edge(src, dst):
+                result_graph.add_edge(src, dst, type=edge_type)
+                if bidirection:
+                    result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type))
+    return cnt, result_graph
+
+
+node_type_list = []
+edge_type_list = []
+node_type_dict = {}
+edge_type_dict = {}
+
+
+def format_graph(g, name):
+    new_g = nx.DiGraph()
+    node_map = {}
+    node_cnt = 0
+    for n in g.nodes:
+        node_map[n] = node_cnt
+        new_g.add_node(node_cnt, type=g.nodes[n]['type'])
+        node_cnt += 1
+    for e in g.edges:
+        new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type'])
+    for n in new_g.nodes:
+        node_type = new_g.nodes[n]['type']
+        if not node_type in node_type_dict:
+            node_type_list.append(node_type)
+            node_type_dict[node_type] = 1
+        else:
+            node_type_dict[node_type] += 1
+    for e in new_g.edges:
+        edge_type = new_g.edges[e]['type']
+        if not edge_type in edge_type_dict:
+            edge_type_list.append(edge_type)
+            edge_type_dict[edge_type] = 1
+        else:
+            edge_type_dict[edge_type] += 1
+    for n in new_g.nodes:
+        new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type'])
+    for e in new_g.edges:
+        new_g.edges[e]['type'] = edge_type_list.index(new_g.edges[e]['type'])
+    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
+        json.dump(nx.node_link_data(new_g), f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
+    args = parser.parse_args()
+    args.stats = False
+    args.verbose = False
+    args.jiffies = False
+    args.input = '../data/wget/raw/'
+    args.output = '../data/wget/processed/'
+    args.final_output = '../data/wget/final/'
+    args.noencode = False
+    if not os.path.exists(args.input):
+        os.mkdir(args.input)
+    if not os.path.exists(args.output):
+        os.mkdir(args.output)
+    if not os.path.exists(args.final_output):
+        os.mkdir(args.final_output)
+    CONSOLE_ARGUMENTS = args
+
+    if args.verbose:
+        logging.basicConfig(filename=args.log, level=logging.DEBUG)
+    cnt = 0
+    for fname in os.listdir(args.input):
+        cnt += 1
+        node_map = dict()
+        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
+        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
+                                      args.noencode)
+        if args.stats:
+            total_nodes = len(node_map)
+            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
+            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
+
+    bidirection = False
+    threshold = 10000000 # infinity
+    interaction_dict = []
+    graph_cnt = 0
+    result_graphs = []
+    input = args.output
+    base = args.final_output
+
+    line_cnt = 0
+    for i in tqdm(range(cnt)):
+        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
+        format_graph(result_graph, '{}{}'.format(base, i))
+        line_cnt += single_cnt
+
+    print(line_cnt // 150)
+    print(len(node_type_list))
+    print(node_type_dict)
+    print(len(edge_type_list))
+    print(edge_type_dict)
+
diff --git a/utils/__pycache__/config.cpython-311.pyc b/utils/__pycache__/config.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..55cf6b81b72fe73af01f3176b76fd0081827f1e8
Binary files /dev/null and b/utils/__pycache__/config.cpython-311.pyc differ
diff --git a/utils/__pycache__/loaddata.cpython-311.pyc b/utils/__pycache__/loaddata.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..6dadf00a8f9a10b0fd4c0f82c44e99ec4cd3ffa0
Binary files /dev/null and b/utils/__pycache__/loaddata.cpython-311.pyc differ
diff --git a/utils/__pycache__/poolers.cpython-311.pyc b/utils/__pycache__/poolers.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..d761b00f03cf59c7afe8815f579b18fed2f7dd8f
Binary files /dev/null and b/utils/__pycache__/poolers.cpython-311.pyc differ
diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..f2d041d338d864c03c2baf1f746aaa8bba806202
Binary files /dev/null and b/utils/__pycache__/utils.cpython-311.pyc differ
diff --git a/utils/config.py b/utils/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee14920f4b2e36dfcd0d7fddca100572837ade20
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,13 @@
+def build_args():
+    args = {}
+    args["dataset"] ="wget"
+    args["device"]=-1
+    args["lr"]=0.001
+    args["weight_decay"]=5e-4
+    args["negative_slope"]=0.2
+    args["mask_rate"]=0.5
+    args["alpha_l"]=3
+    args["optimizer"]="adam"
+    args["loss_fn"]='sce'
+    args["pooling"] = "mean"
+    return args
diff --git a/utils/loaddata.py b/utils/loaddata.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca48e0fd60aa69712f48c06f8083b37f679cf797
--- /dev/null
+++ b/utils/loaddata.py
@@ -0,0 +1,207 @@
+import pickle as pkl
+import time
+import torch.nn.functional as F
+import dgl
+import networkx as nx
+import json
+from tqdm import tqdm
+import os
+
+
+
+class StreamspotDataset(dgl.data.DGLDataset):
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(StreamspotDataset, self).__init__(name=name)
+        if name == 'streamspot':
+            path = './data/streamspot'
+            num_graphs = 600
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                if 300 <= idx <= 399:
+                    self.labels.append(1)
+                else:
+                    self.labels.append(0)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+
+class WgetDataset(dgl.data.DGLDataset):
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(WgetDataset, self).__init__(name=name)
+        if name == 'wget':
+            pathattack = '/data/wget/finalattack'
+            pathbenin = 'data/wget/finalbenin'
+            num_graphs_benin = 125
+            num_graphs_attack = 25
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs_benin)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                self.labels.append(0)
+                
+            for i in tqdm(range(num_graphs_attack)):
+                idx = i
+                g = dgl.from_networkx(
+                    nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))),
+                    node_attrs=['type'],
+                    edge_attrs=['type']
+                )
+                self.graphs.append(g)
+                self.labels.append(1)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+
+def load_rawdata(name):
+    if name == 'streamspot':
+        path = './data/streamspot'
+        if os.path.exists(path + '/graphs.pkl'):
+            print('Loading processed {} dataset...'.format(name))
+            raw_data = pkl.load(open(path + '/graphs.pkl', 'rb'))
+        else:
+            raw_data = StreamspotDataset(name)
+            pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb'))
+    elif name == 'wget':
+        path = './data/wget'
+        if os.path.exists(path + '/graphs.pkl'):
+            print('Loading processed {} dataset...'.format(name))
+            raw_data = pkl.load(open(path + '/graphs.pkl', 'rb'))
+        else:
+            raw_data = WgetDataset(name)
+            pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb'))
+    else:
+        raise NotImplementedError
+    return raw_data
+
+
+def load_batch_level_dataset(dataset_name):
+    dataset = load_rawdata(dataset_name)
+    graph, _ = dataset[0]
+    node_feature_dim = 0
+    for g, _ in dataset:
+        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
+    edge_feature_dim = 0
+    for g, _ in dataset:
+        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
+    node_feature_dim += 1
+    edge_feature_dim += 1
+    full_dataset = [i for i in range(len(dataset))]
+    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
+    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
+
+    return {'dataset': dataset,
+            'train_index': train_dataset,
+            'full_index': full_dataset,
+            'n_feat': node_feature_dim,
+            'e_feat': edge_feature_dim}
+
+
+def transform_graph(g, node_feature_dim, edge_feature_dim):
+    new_g = g.clone()
+    new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float()
+    new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float()
+    return new_g
+
+
+def preload_entity_level_dataset(path):
+    path = './data/' + path
+    if os.path.exists(path + '/metadata.json'):
+        pass
+    else:
+        print('transforming')
+        train_gs = [dgl.from_networkx(
+            nx.node_link_graph(g),
+            node_attrs=['type'],
+            edge_attrs=['type']
+        ) for g in pkl.load(open(path + '/train.pkl', 'rb'))]
+        print('transforming')
+        test_gs = [dgl.from_networkx(
+            nx.node_link_graph(g),
+            node_attrs=['type'],
+            edge_attrs=['type']
+        ) for g in pkl.load(open(path + '/test.pkl', 'rb'))]
+        malicious = pkl.load(open(path + '/malicious.pkl', 'rb'))
+
+        node_feature_dim = 0
+        for g in train_gs:
+            node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim)
+        for g in test_gs:
+            node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim)
+        node_feature_dim += 1
+        edge_feature_dim = 0
+        for g in train_gs:
+            edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim)
+        for g in test_gs:
+            edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim)
+        edge_feature_dim += 1
+        result_test_gs = []
+        for g in test_gs:
+            g = transform_graph(g, node_feature_dim, edge_feature_dim)
+            result_test_gs.append(g)
+        result_train_gs = []
+        for g in train_gs:
+            g = transform_graph(g, node_feature_dim, edge_feature_dim)
+            result_train_gs.append(g)
+        metadata = {
+            'node_feature_dim': node_feature_dim,
+            'edge_feature_dim': edge_feature_dim,
+            'malicious': malicious,
+            'n_train': len(result_train_gs),
+            'n_test': len(result_test_gs)
+        }
+        with open(path + '/metadata.json', 'w', encoding='utf-8') as f:
+            json.dump(metadata, f)
+        for i, g in enumerate(result_train_gs):
+            with open(path + '/train{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
+        for i, g in enumerate(result_test_gs):
+            with open(path + '/test{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
+
+
+def load_metadata(path):
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f:
+        metadata = json.load(f)
+    return metadata
+
+
+def load_entity_level_dataset(path, t, n):
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f:
+        data = pkl.load(f)
+    return data
diff --git a/utils/poolers.py b/utils/poolers.py
new file mode 100755
index 0000000000000000000000000000000000000000..4b9dc4aee1eac42f578952596a5e400adbd1e391
--- /dev/null
+++ b/utils/poolers.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+
+class Pooling(nn.Module):
+    def __init__(self, pooler):
+        super(Pooling, self).__init__()
+        self.pooler = pooler
+
+    def forward(self, graph, feat, t=None):
+        feat = feat
+        # Implement node type-specific pooling
+        with graph.local_scope():
+            if t is None:
+                if self.pooler == 'mean':
+                    return feat.mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat.sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat.max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
+            elif isinstance(t, int):
+                mask = (graph.ndata['type'] == t)
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
+            else:
+                mask = (graph.ndata['type'] == t[0])
+                for i in range(1, len(t)):
+                    mask |= (graph.ndata['type'] == t[i])
+                if self.pooler == 'mean':
+                    return feat[mask].mean(0, keepdim=True)
+                elif self.pooler == 'sum':
+                    return feat[mask].sum(0, keepdim=True)
+                elif self.pooler == 'max':
+                    return feat[mask].max(0, keepdim=True)
+                else:
+                    raise NotImplementedError
diff --git a/utils/streamspot_parser.py b/utils/streamspot_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..07687828ea5ecf56f9859955e599afb584826f1f
--- /dev/null
+++ b/utils/streamspot_parser.py
@@ -0,0 +1,55 @@
+import networkx as nx
+from tqdm import tqdm
+import json
+raw_path = '../data/streamspot/'
+
+NUM_GRAPHS = 600
+node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
+                  'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G']
+node_type_set = set(node_type_dict)
+edge_type_set = set(edge_type_dict)
+print(raw_path)
+count_graph = 0
+with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f:
+    print("Reading")
+    lines = f.readlines()
+    print("starting")
+    g = nx.DiGraph()
+    node_map = {}
+    count_node = 0
+    for line in tqdm(lines):
+        src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t')
+        graph_id = int(graph_id)
+        if src_type not in node_type_set or dst_type not in node_type_set:
+            continue
+        if etype not in edge_type_set:
+            continue
+        if graph_id != count_graph:
+            count_graph += 1
+            for n in g.nodes():
+                g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type'])
+            for e in g.edges():
+                g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type'])
+            f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8')
+            json.dump(nx.node_link_data(g), f1)
+            assert graph_id == count_graph
+            g = nx.DiGraph()
+            count_node = 0
+        if src not in node_map:
+            node_map[src] = count_node
+            g.add_node(count_node, type=src_type)
+            count_node += 1
+        if dst not in node_map:
+            node_map[dst] = count_node
+            g.add_node(count_node, type=dst_type)
+            count_node += 1
+        if not g.has_edge(node_map[src], node_map[dst]):
+            g.add_edge(node_map[src], node_map[dst], type=etype)
+    count_graph += 1
+    for n in g.nodes():
+        g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type'])
+    for e in g.edges():
+        g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type'])
+    f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8')
+    json.dump(nx.node_link_data(g), f1)
diff --git a/utils/trace_parser.py b/utils/trace_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7
--- /dev/null
+++ b/utils/trace_parser.py
@@ -0,0 +1,288 @@
+import argparse
+import json
+import os
+import random
+import re
+
+from tqdm import tqdm
+import networkx as nx
+import pickle as pkl
+
+
+node_type_dict = {}
+edge_type_dict = {}
+node_type_cnt = 0
+edge_type_cnt = 0
+
+metadata = {
+    'trace':{
+        'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'],
+        'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4']
+    },
+    'theia':{
+            'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'],
+            'test': ['ta1-theia-e3-official-6r.json.8']
+    },
+    'cadets':{
+            'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'],
+            'test': ['ta1-cadets-e3-official-2.json']
+    }
+}
+
+
+pattern_uuid = re.compile(r'uuid\":\"(.*?)\"')
+pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_type = re.compile(r'type\":\"(.*?)\"')
+pattern_time = re.compile(r'timestampNanos\":(.*?),')
+pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"')
+pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"')
+pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"')
+
+adversarial = 'None'
+
+
+def read_single_graph(dataset, malicious, path, test=False):
+    global node_type_cnt, edge_type_cnt
+    g = nx.DiGraph()
+    print('converting {} ...'.format(path))
+    path = '../data/{}/'.format(dataset) + path + '.txt'
+    f = open(path, 'r')
+    lines = []
+    for l in f.readlines():
+        split_line = l.split('\t')
+        src, src_type, dst, dst_type, edge_type, ts = split_line
+        ts = int(ts)
+        if not test:
+            if src in malicious or dst in malicious:
+                if src in malicious and src_type != 'MemoryObject':
+                    continue
+                if dst in malicious and dst_type != 'MemoryObject':
+                    continue
+
+        if src_type not in node_type_dict:
+            node_type_dict[src_type] = node_type_cnt
+            node_type_cnt += 1
+        if dst_type not in node_type_dict:
+            node_type_dict[dst_type] = node_type_cnt
+            node_type_cnt += 1
+        if edge_type not in edge_type_dict:
+            edge_type_dict[edge_type] = edge_type_cnt
+            edge_type_cnt += 1
+        if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type:
+            lines.append([dst, src, dst_type, src_type, edge_type, ts])
+        else:
+            lines.append([src, dst, src_type, dst_type, edge_type, ts])
+    lines.sort(key=lambda l: l[5])
+
+    node_map = {}
+    node_type_map = {}
+    node_cnt = 0
+    node_list = []
+    for l in lines:
+        src, dst, src_type, dst_type, edge_type = l[:5]
+        src_type_id = node_type_dict[src_type]
+        dst_type_id = node_type_dict[dst_type]
+        edge_type_id = edge_type_dict[edge_type]
+        if src not in node_map:
+            if test:
+                if adversarial == 'MFE' or adversarial == 'MCE':
+                    if src in malicious:
+                        src_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test:
+                if adversarial == 'BFP':
+                    if src not in malicious:
+                        i = random.randint(1, 20)
+                        if i == 1:
+                            src_type_id = node_type_dict['NetFlowObject']
+            node_map[src] = node_cnt
+            g.add_node(node_cnt, type=src_type_id)
+            node_list.append(src)
+            node_type_map[src] = src_type
+            node_cnt += 1
+        if dst not in node_map:
+            if test:
+                if adversarial == 'MFE' or adversarial == 'MCE':
+                    if dst in malicious:
+                        dst_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test:
+                if adversarial == 'BFP':
+                    if dst not in malicious:
+                        i = random.randint(1, 20)
+                        if i == 1:
+                            dst_type_id = node_type_dict['NetFlowObject']
+            node_map[dst] = node_cnt
+            g.add_node(node_cnt, type=dst_type_id)
+            node_type_map[dst] = dst_type
+            node_list.append(dst)
+            node_cnt += 1
+        if not g.has_edge(node_map[src], node_map[dst]):
+            g.add_edge(node_map[src], node_map[dst], type=edge_type_id)
+    if (adversarial == 'MSE' or adversarial == 'MCE') and test:
+        for i, node in enumerate(node_list):
+            if node in malicious:
+                while True:
+                    another_node = random.choice(node_list)
+                    if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT'])
+                        print(node, another_node)
+                        break
+                    if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject':
+                        break
+    return node_map, g
+
+
+def preprocess_dataset(dataset):
+    id_nodetype_map = {}
+    id_nodename_map = {}
+    for file in os.listdir('../data/{}/'.format(dataset)):
+        if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file:
+            print('reading {} ...'.format(file))
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue
+                if len(pattern_uuid.findall(line)) == 0: print(line)
+                uuid = pattern_uuid.findall(line)[0]
+                subject_type = pattern_type.findall(line)
+
+                if len(subject_type) < 1:
+                    if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line:
+                        subject_type = 'MemoryObject'
+                    if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line:
+                        subject_type = 'NetFlowObject'
+                    if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line:
+                        subject_type = 'UnnamedPipeObject'
+                else:
+                    subject_type = subject_type[0]
+
+                if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']:
+                    continue
+                id_nodetype_map[uuid] = subject_type
+                if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_file_name.findall(line)[0]
+                elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_process_name.findall(line)[0]
+                elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0]
+    for key in metadata[dataset]:
+        for file in metadata[dataset][key]:
+            if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'):
+                continue
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8')
+            print('processing {} ...'.format(file))
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line:
+                    edgeType = pattern_type.findall(line)[0]
+                    timestamp = pattern_time.findall(line)[0]
+                    srcId = pattern_src.findall(line)
+
+                    if len(srcId) == 0: continue
+                    srcId = srcId[0]
+                    if not srcId in id_nodetype_map:
+                        continue
+                    srcType = id_nodetype_map[srcId]
+                    dstId1 = pattern_dst1.findall(line)
+                    if len(dstId1) > 0 and dstId1[0] != 'null':
+                        dstId1 = dstId1[0]
+                        if not dstId1 in id_nodetype_map:
+                            continue
+                        dstType1 = id_nodetype_map[dstId1]
+                        this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str(
+                            dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n'
+                        fw.write(this_edge1)
+
+                    dstId2 = pattern_dst2.findall(line)
+                    if len(dstId2) > 0 and dstId2[0] != 'null':
+                        dstId2 = dstId2[0]
+                        if not dstId2 in id_nodetype_map.keys():
+                            continue
+                        dstType2 = id_nodetype_map[dstId2]
+                        this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str(
+                            dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n'
+                        fw.write(this_edge2)
+            fw.close()
+            f.close()
+    if len(id_nodename_map) != 0:
+        fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8')
+        json.dump(id_nodename_map, fw)
+    if len(id_nodetype_map) != 0:
+        fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8')
+        json.dump(id_nodetype_map, fw)
+
+
+def read_graphs(dataset):
+    malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset)
+    f = open(malicious_entities, 'r')
+    malicious_entities = set()
+    for l in f.readlines():
+        malicious_entities.add(l.lstrip().rstrip())
+
+    preprocess_dataset(dataset)
+    train_gs = []
+    for file in metadata[dataset]['train']:
+        _, train_g = read_single_graph(dataset, malicious_entities, file, False)
+        train_gs.append(train_g)
+    test_gs = []
+    test_node_map = {}
+    count_node = 0
+    for file in metadata[dataset]['test']:
+        node_map, test_g = read_single_graph(dataset, malicious_entities, file, True)
+        assert len(node_map) == test_g.number_of_nodes()
+        test_gs.append(test_g)
+        for key in node_map:
+            if key not in test_node_map:
+                test_node_map[key] = node_map[key] + count_node
+        count_node += test_g.number_of_nodes()
+
+    if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)):
+        with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodename_map = json.load(f)
+        with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodetype_map = json.load(f)
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject':
+                final_malicious_entities.append(test_node_map[e])
+                if e in id_nodename_map:
+                    malicious_names.append(id_nodename_map[e])
+                    f.write('{}\t{}\n'.format(e, id_nodename_map[e]))
+                else:
+                    malicious_names.append(e)
+                    f.write('{}\t{}\n'.format(e, e))
+    else:
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map:
+                final_malicious_entities.append(test_node_map[e])
+                malicious_names.append(e)
+                f.write('{}\t{}\n'.format(e, e))
+
+    pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb'))
+    pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb'))
+    pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb'))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='CDM Parser')
+    parser.add_argument("--dataset", type=str, default="trace")
+    args = parser.parse_args()
+    if args.dataset not in ['trace', 'theia', 'cadets']:
+        raise NotImplementedError
+    read_graphs(args.dataset)
+
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..bcf9a481407696d73f30fe1dde279154a05702b3
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import numpy as np
+import random
+import torch.optim as optim
+
+
+def create_optimizer(opt, model, lr, weight_decay):
+    opt_lower = opt.lower()
+    parameters = model.parameters()
+    opt_args = dict(lr=lr, weight_decay=weight_decay)
+    optimizer = None
+    opt_split = opt_lower.split("_")
+    opt_lower = opt_split[-1]
+    if opt_lower == "adam":
+        optimizer = optim.Adam(parameters, **opt_args)
+    elif opt_lower == "adamw":
+        optimizer = optim.AdamW(parameters, **opt_args)
+    elif opt_lower == "adadelta":
+        optimizer = optim.Adadelta(parameters, **opt_args)
+    elif opt_lower == "radam":
+        optimizer = optim.RAdam(parameters, **opt_args)
+    elif opt_lower == "sgd":
+        opt_args["momentum"] = 0.9
+        return optim.SGD(parameters, **opt_args)
+    else:
+        assert False and "Invalid optimizer"
+    return optimizer
+
+
+def random_shuffle(x, y):
+    idx = list(range(len(x)))
+    random.shuffle(idx)
+    return x[idx], y[idx]
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.determinstic = True
+
+
+def create_activation(name):
+    if name == "relu":
+        return nn.ReLU()
+    elif name == "gelu":
+        return nn.GELU()
+    elif name == "prelu":
+        return nn.PReLU()
+    elif name is None:
+        return nn.Identity()
+    elif name == "elu":
+        return nn.ELU()
+    else:
+        raise NotImplementedError(f"{name} is not implemented.")
+
+
+def create_norm(name):
+    if name == "layernorm":
+        return nn.LayerNorm
+    elif name == "batchnorm":
+        return nn.BatchNorm1d
+    elif name == "graphnorm":
+        return partial(NormLayer, norm_type="groupnorm")
+    else:
+        return None
+
+
+class NormLayer(nn.Module):
+    def __init__(self, hidden_dim, norm_type):
+        super().__init__()
+        if norm_type == "batchnorm":
+            self.norm = nn.BatchNorm1d(hidden_dim)
+        elif norm_type == "layernorm":
+            self.norm = nn.LayerNorm(hidden_dim)
+        elif norm_type == "graphnorm":
+            self.norm = norm_type
+            self.weight = nn.Parameter(torch.ones(hidden_dim))
+            self.bias = nn.Parameter(torch.zeros(hidden_dim))
+
+            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
+        else:
+            raise NotImplementedError
+
+    def forward(self, graph, x):
+        tensor = x
+        if self.norm is not None and type(self.norm) != str:
+            return self.norm(tensor)
+        elif self.norm is None:
+            return tensor
+
+        batch_list = graph.batch_num_nodes
+        batch_size = len(batch_list)
+        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
+        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
+        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
+        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        mean = mean.scatter_add_(0, batch_index, tensor)
+        mean = (mean.T / batch_list).T
+        mean = mean.repeat_interleave(batch_list, dim=0)
+
+        sub = tensor - mean * self.mean_scale
+
+        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+        std = std.scatter_add_(0, batch_index, sub.pow(2))
+        std = ((std.T / batch_list).T + 1e-6).sqrt()
+        std = std.repeat_interleave(batch_list, dim=0)
+        return self.weight * sub / std + self.bias
diff --git a/utils/wget_parser.py b/utils/wget_parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055
--- /dev/null
+++ b/utils/wget_parser.py
@@ -0,0 +1,820 @@
+import argparse
+import json
+import os
+
+import xxhash
+import tqdm
+import logging
+import networkx as nx
+from tqdm import tqdm
+import time
+import datetime
+
+
+valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link']
+CONSOLE_ARGUMENTS = None
+
+
+def hashgen(l):
+    """Generate a single hash value from a list. @l is a list of
+    string values, which can be properties of a node/edge. This
+    function returns a single hashed integer value."""
+    hasher = xxhash.xxh64()
+    for e in l:
+        hasher.update(e)
+    return hasher.intdigest()
+
+
+def parse_nodes(json_string, node_map):
+    """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity").
+    Parsed nodes populate @node_map, which is a dictionary that maps the node's UID,
+    which is assigned by CamFlow to uniquely identify a node object, to a hashed
+    value (in str) which represents the 'type' of the node. """
+    json_object = None
+    try:
+        # use "ignore" if non-decodeable exists in the @json_string
+        json_object = json.loads(json_string)
+    except Exception as e:
+        print("Exception ({}) occurred when parsing a node in JSON:".format(e))
+        print(json_string)
+        exit(1)
+    if "activity" in json_object:
+        activity = json_object["activity"]
+        for uid in activity:
+            if not uid in node_map:  # only parse unseen nodes
+                if "prov:type" not in activity[uid]:
+                    # a node must have a type.
+                    # record this issue if logging is turned on
+                    if CONSOLE_ARGUMENTS.verbose:
+                        logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid))
+                else:
+                    node_map[uid] = activity[uid]["prov:type"]
+
+    if "entity" in json_object:
+        entity = json_object["entity"]
+        for uid in entity:
+            if not uid in node_map:
+                if "prov:type" not in entity[uid]:
+                    if CONSOLE_ARGUMENTS.verbose:
+                        logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid))
+                else:
+                    node_map[uid] = entity[uid]["prov:type"]
+
+
+def parse_all_nodes(filename, node_map):
+    """Parse all nodes in CamFlow data. @filename is the file path of
+    the CamFlow data to parse. @node_map contains the mappings of all
+    CamFlow nodes to their hashed attributes. """
+    description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename)
+    pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+    with open(filename, 'r') as f:
+        # each line in CamFlow data could contain multiple
+        # provenance nodes, we call @parse_nodes routine.
+        for line in f:
+            pb.update()  # for progress tracking
+            parse_nodes(line, node_map)
+    f.close()
+    pb.close()
+
+
+def parse_all_edges(inputfile, outputfile, node_map, noencode):
+    """Parse all edges (including their timestamp) from CamFlow data file @inputfile
+    to an @outputfile. Before this function is called, parse_all_nodes should be called
+    to populate the @node_map for all nodes in the CamFlow file. If @noencode is set,
+    we do not hash the nodes' original UUIDs generated by CamFlow to integers. This
+    function returns the total number of valid edge parsed from CamFlow dataset.
+
+    The output edgelist has the following format for each line, if -s is not set:
+        <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>
+    If -s is set, each line would look like:
+        <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>"""
+    total_edges = 0
+    smallest_timestamp = None
+    # scan through the entire file to find the smallest timestamp from all the edges.
+    # this step is only needed if we need to add some statistical information.
+    if CONSOLE_ARGUMENTS.stats:
+        description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile)
+        pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+        with open(inputfile, 'r') as f:
+            for line in f:
+                pb.update()
+                json_object = json.loads(line)
+
+                if "used" in json_object:
+                    used = json_object["used"]
+                    for uid in used:
+                        if "prov:type" not in used[uid]:
+                            continue
+                        if "cf:date" not in used[uid]:
+                            continue
+                        if "prov:entity" not in used[uid]:
+                            continue
+                        if "prov:activity" not in used[uid]:
+                            continue
+                        srcUUID = used[uid]["prov:entity"]
+                        dstUUID = used[uid]["prov:activity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = used[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasGeneratedBy" in json_object:
+                    wasGeneratedBy = json_object["wasGeneratedBy"]
+                    for uid in wasGeneratedBy:
+                        if "prov:type" not in wasGeneratedBy[uid]:
+                            continue
+                        if "cf:date" not in wasGeneratedBy[uid]:
+                            continue
+                        if "prov:entity" not in wasGeneratedBy[uid]:
+                            continue
+                        if "prov:activity" not in wasGeneratedBy[uid]:
+                            continue
+                        srcUUID = wasGeneratedBy[uid]["prov:activity"]
+                        dstUUID = wasGeneratedBy[uid]["prov:entity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasGeneratedBy[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasInformedBy" in json_object:
+                    wasInformedBy = json_object["wasInformedBy"]
+                    for uid in wasInformedBy:
+                        if "prov:type" not in wasInformedBy[uid]:
+                            continue
+                        if "cf:date" not in wasInformedBy[uid]:
+                            continue
+                        if "prov:informant" not in wasInformedBy[uid]:
+                            continue
+                        if "prov:informed" not in wasInformedBy[uid]:
+                            continue
+                        srcUUID = wasInformedBy[uid]["prov:informant"]
+                        dstUUID = wasInformedBy[uid]["prov:informed"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasInformedBy[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasDerivedFrom" in json_object:
+                    wasDerivedFrom = json_object["wasDerivedFrom"]
+                    for uid in wasDerivedFrom:
+                        if "prov:type" not in wasDerivedFrom[uid]:
+                            continue
+                        if "cf:date" not in wasDerivedFrom[uid]:
+                            continue
+                        if "prov:usedEntity" not in wasDerivedFrom[uid]:
+                            continue
+                        if "prov:generatedEntity" not in wasDerivedFrom[uid]:
+                            continue
+                        srcUUID = wasDerivedFrom[uid]["prov:usedEntity"]
+                        dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasDerivedFrom[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+
+                if "wasAssociatedWith" in json_object:
+                    wasAssociatedWith = json_object["wasAssociatedWith"]
+                    for uid in wasAssociatedWith:
+                        if "prov:type" not in wasAssociatedWith[uid]:
+                            continue
+                        if "cf:date" not in wasAssociatedWith[uid]:
+                            continue
+                        if "prov:agent" not in wasAssociatedWith[uid]:
+                            continue
+                        if "prov:activity" not in wasAssociatedWith[uid]:
+                            continue
+                        srcUUID = wasAssociatedWith[uid]["prov:agent"]
+                        dstUUID = wasAssociatedWith[uid]["prov:activity"]
+                        if srcUUID not in node_map:
+                            continue
+                        if dstUUID not in node_map:
+                            continue
+                        timestamp_str = wasAssociatedWith[uid]["cf:date"]
+                        ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                        if smallest_timestamp == None or ts < smallest_timestamp:
+                            smallest_timestamp = ts
+        f.close()
+        pb.close()
+
+    # we will go through the CamFlow data (again) and output edgelist to a file
+    output = open(outputfile, "w+")
+    description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile)
+    pb = tqdm(desc=description, mininterval=1.0, unit=" recs")
+    with open(inputfile, 'r') as f:
+        for line in f:
+            pb.update()
+            json_object = json.loads(line)
+
+            if "used" in json_object:
+                used = json_object["used"]
+                for uid in used:
+                    if "prov:type" not in used[uid]:
+                        # an edge must have a type; if not,
+                        # we will have to skip the edge. Log
+                        # this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "used"
+                    # cf:id is used as logical timestamp to order edges
+                    if "cf:id" not in used[uid]:
+                        # an edge must have a logical timestamp;
+                        # if not we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = used[uid]["cf:id"]
+                    if "prov:entity" not in used[uid]:
+                        # an edge's source node must exist;
+                        # if not, we will have to skip the
+                        # edge. Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in used[uid]:
+                        # an edge's destination node must exist;
+                        # if not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"],
+                                                                                            uid))
+                        continue
+                    srcUUID = used[uid]["prov:entity"]
+                    dstUUID = used[uid]["prov:activity"]
+                    # both source and destination node must
+                    # exist in @node_map; if not, we will
+                    # have to skip the edge. Log this issue
+                    # if verbose is set.
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug(
+                                "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in used[uid]:
+                        # an edge must have a timestamp; if
+                        # not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        # we only record @adjusted_ts if we need
+                        # to record stats of CamFlow dataset.
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = used[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in used[uid]:
+                        # an edge must have a jiffies timestamp; if
+                        # not, we will have to skip the edge.
+                        # Log this issue if verbose is set.
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (used) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        # we only record @jiffies if
+                        # the option is set
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = used[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp))
+
+            if "wasGeneratedBy" in json_object:
+                wasGeneratedBy = json_object["wasGeneratedBy"]
+                for uid in wasGeneratedBy:
+                    if "prov:type" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasGeneratedBy"
+                    if "cf:id" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasGeneratedBy[uid]["cf:id"]
+                    if "prov:entity" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasGeneratedBy[uid]["prov:activity"]
+                    dstUUID = wasGeneratedBy[uid]["prov:entity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format(
+                                wasGeneratedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasGeneratedBy[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasGeneratedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasGeneratedBy[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasInformedBy" in json_object:
+                wasInformedBy = json_object["wasInformedBy"]
+                for uid in wasInformedBy:
+                    if "prov:type" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasInformedBy"
+                    if "cf:id" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasInformedBy[uid]["cf:id"]
+                    if "prov:informant" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    if "prov:informed" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasInformedBy[uid]["prov:informant"]
+                    dstUUID = wasInformedBy[uid]["prov:informed"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format(
+                                wasInformedBy[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasInformedBy[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasInformedBy[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasInformedBy[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasDerivedFrom" in json_object:
+                wasDerivedFrom = json_object["wasDerivedFrom"]
+                for uid in wasDerivedFrom:
+                    if "prov:type" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasDerivedFrom"
+                    if "cf:id" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid))
+                            continue
+                    else:
+                        timestamp = wasDerivedFrom[uid]["cf:id"]
+                    if "prov:usedEntity" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    if "prov:generatedEntity" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasDerivedFrom[uid]["prov:usedEntity"]
+                    dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format(
+                                wasDerivedFrom[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasDerivedFrom[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasDerivedFrom[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasDerivedFrom[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+
+            if "wasAssociatedWith" in json_object:
+                wasAssociatedWith = json_object["wasAssociatedWith"]
+                for uid in wasAssociatedWith:
+                    if "prov:type" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid))
+                        continue
+                    else:
+                        edgetype = "wasAssociatedWith"
+                    if "cf:id" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid))
+                        continue
+                    else:
+                        timestamp = wasAssociatedWith[uid]["cf:id"]
+                    if "prov:agent" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    if "prov:activity" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    srcUUID = wasAssociatedWith[uid]["prov:agent"]
+                    dstUUID = wasAssociatedWith[uid]["prov:activity"]
+                    if srcUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        srcVal = node_map[srcUUID]
+                    if dstUUID not in node_map:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format(
+                                wasAssociatedWith[uid]["prov:type"], uid))
+                        continue
+                    else:
+                        dstVal = node_map[dstUUID]
+                    if "cf:date" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            ts_str = wasAssociatedWith[uid]["cf:date"]
+                            ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple())
+                            adjusted_ts = ts - smallest_timestamp
+                    if "cf:jiffies" not in wasAssociatedWith[uid]:
+                        if CONSOLE_ARGUMENTS.verbose:
+                            logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid))
+                        continue
+                    else:
+                        if CONSOLE_ARGUMENTS.jiffies:
+                            jiffies = wasAssociatedWith[uid]["cf:jiffies"]
+                    total_edges += 1
+                    if noencode:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp))
+                    else:
+                        if CONSOLE_ARGUMENTS.stats:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  adjusted_ts))
+                        elif CONSOLE_ARGUMENTS.jiffies:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal,
+                                                                  dstVal, edgetype, timestamp,
+                                                                  jiffies))
+                        else:
+                            output.write(
+                                "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal,
+                                                               edgetype, timestamp))
+    f.close()
+    output.close()
+    pb.close()
+    return total_edges
+
+def read_single_graph(file_name, threshold):
+    graph = []
+    edge_cnt = 0
+    with open(file_name, 'r') as f:
+        for line in f:
+            try:
+                edge = line.strip().split("\t")
+                new_edge = [edge[0], edge[1]]
+                attributes = edge[2].strip().split(":")
+                source_node_type = attributes[0]
+                destination_node_type = attributes[1]
+                edge_type = attributes[2]
+                edge_order = attributes[3]
+
+                new_edge.append(source_node_type)
+                new_edge.append(destination_node_type)
+                new_edge.append(edge_type)
+                new_edge.append(edge_order)
+                graph.append(new_edge)
+                edge_cnt += 1
+            except:
+                print("{}".format(line))
+    f.close()
+    graph.sort(key=lambda e: e[5])
+    if len(graph) <= threshold:
+        return graph
+    else:
+        return graph[:threshold]
+
+
+def process_graph(name, threshold):
+    graph = read_single_graph(name, threshold)
+    result_graph = nx.DiGraph()
+    cnt = 0
+    for num, edge in enumerate(graph):
+        cnt += 1
+        src, dst, src_type, dst_type, edge_type = edge[:5]
+        if True:# src_type in valid_node_type and dst_type in valid_node_type:
+            if not result_graph.has_node(src):
+                result_graph.add_node(src, type=src_type)
+            if not result_graph.has_node(dst):
+                result_graph.add_node(dst, type=dst_type)
+            if not result_graph.has_edge(src, dst):
+                result_graph.add_edge(src, dst, type=edge_type)
+                if bidirection:
+                    result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type))
+    return cnt, result_graph
+
+
+node_type_list = []
+edge_type_list = []
+node_type_dict = {}
+edge_type_dict = {}
+
+
+def format_graph(g, name):
+    new_g = nx.DiGraph()
+    node_map = {}
+    node_cnt = 0
+    for n in g.nodes:
+        node_map[n] = node_cnt
+        new_g.add_node(node_cnt, type=g.nodes[n]['type'])
+        node_cnt += 1
+    for e in g.edges:
+        new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type'])
+    for n in new_g.nodes:
+        node_type = new_g.nodes[n]['type']
+        if not node_type in node_type_dict:
+            node_type_list.append(node_type)
+            node_type_dict[node_type] = 1
+        else:
+            node_type_dict[node_type] += 1
+    for e in new_g.edges:
+        edge_type = new_g.edges[e]['type']
+        if not edge_type in edge_type_dict:
+            edge_type_list.append(edge_type)
+            edge_type_dict[edge_type] = 1
+        else:
+            edge_type_dict[edge_type] += 1
+    for n in new_g.nodes:
+        new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type'])
+    for e in new_g.edges:
+        new_g.edges[e]['type'] = edge_type_list.index(new_g.edges[e]['type'])
+    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
+        json.dump(nx.node_link_data(new_g), f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
+    args = parser.parse_args()
+    args.stats = False
+    args.verbose = False
+    args.jiffies = False
+    args.input = '../data/wget/raw/'
+    args.output = '../data/wget/processed/'
+    args.final_output = '../data/wget/final/'
+    args.noencode = False
+    if not os.path.exists(args.input):
+        os.mkdir(args.input)
+    if not os.path.exists(args.output):
+        os.mkdir(args.output)
+    if not os.path.exists(args.final_output):
+        os.mkdir(args.final_output)
+    CONSOLE_ARGUMENTS = args
+
+    if args.verbose:
+        logging.basicConfig(filename=args.log, level=logging.DEBUG)
+    cnt = 0
+    for fname in os.listdir(args.input):
+        cnt += 1
+        node_map = dict()
+        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
+        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
+                                      args.noencode)
+        if args.stats:
+            total_nodes = len(node_map)
+            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
+            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
+
+    bidirection = False
+    threshold = 10000000 # infinity
+    interaction_dict = []
+    graph_cnt = 0
+    result_graphs = []
+    input = args.output
+    base = args.final_output
+
+    line_cnt = 0
+    for i in tqdm(range(cnt)):
+        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
+        format_graph(result_graph, '{}{}'.format(base, i))
+        line_cnt += single_cnt
+
+    print(line_cnt // 150)
+    print(len(node_type_list))
+    print(node_type_dict)
+    print(len(edge_type_list))
+    print(edge_type_dict)
+