diff --git a/README.md b/README.md index 18d22331f41c4693d1558938b3e7471b4468afe2..22b75bfacc1affcdc8d9e94144606c1fc1eadc6e 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,15 @@ -# FEDHE-Graph +# CONTINUUM-FEDHE-Graph -Welcome to the official repository housing the FEDHE-Graph implementation for Magic! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Magic, a comprehensive framework for federated learning research. +Welcome to the official repository housing the FEDHE-Graph implementation for our solution Continuum! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Continuum, a comprehensive framework for federated learning research. - + + - -Original project: https://github.com/FDUDSDE/MAGIC - ## Environment Setup -The command are used in an environnement that consist of Ubuntu 22.04 with miniconda installed +The commands are used in an environment consisting of Windows 11 with Anaconda installed -Original project: https://github.com/FDUDSDE/MAGIC First create the conda environnement for fedml with MPI support @@ -24,13 +21,13 @@ conda install -c conda-forge mpi4py openmpi pip install "fedml[MPI]" ``` -Clone the MAGIC FedML project onto your current folder +Clone the Continuum FedML project into your current folder ``` -git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML +git clone https://github.com/kamelferrahi/Continuum_FL ``` -Install the necessary packages for Magic to run +Install the necessary packages for Continuum to run ``` conda install -c conda-forge aiohttp=3.9.1 aiosignal=1.3.1 anyio=4.2.0 attrdict=2.0.1 attrs=23.2.0 blis=0.7.11 boto3=1.34.12 botocore=1.34.12 brotli=1.1.0 catalogue=2.0.10 certifi=2023.11.17 chardet=5.2.0 charset-normalizer=3.3.2 click=8.1.7 cloudpathlib=0.16.0 confection=0.1.4 contourpy=1.2.0 cycler=0.12.1 cymem=2.0.8 dgl=1.1.3 dill=0.3.7 fastapi=0.92.0 fedml=0.8.13.post2 filelock=3.13.1 fonttools=4.47.0 frozenlist=1.4.1 fsspec=2023.12.2 gensim=4.3.2 gevent=23.9.1 geventhttpclient=2.0.9 gitdb=4.0.11 GitPython=3.1.40 GPUtil=1.4.0 graphviz=0.8.4 greenlet=3.0.3 h11=0.14.0 h5py=3.10.0 httpcore=1.0.2 httpx=0.26.0 idna=3.6 Jinja2=3.1.2 jmespath=1.0.1 joblib=1.3.2 kiwisolver=1.4.5 langcodes=3.3.0 MarkupSafe=2.1.3 matplotlib=3.8.2 mpi4py=3.1.3 mpmath=1.3.0 multidict=6.0.4 multiprocess=0.70.15 murmurhash=1.0.10 networkx=2.8.8 ntplib=0.4.0 numpy=1.26.3 nvidia-cublas-cu12=12.1.3.1 nvidia-cuda-cupti-cu12=12.1.105 nvidia-cuda-nvrtc-cu12=12.1.105 nvidia-cuda-runtime-cu12=12.1.105 nvidia-cudnn-cu12=8.9.2.26 nvidia-cufft-cu12=11.0.2.54 nvidia-curand-cu12=10.3.2.106 nvidia-cusolver-cu12=11.4.5.107 nvidia-cusparse-cu12=12.1.0.106 nvidia-nccl-cu12=2.18.1 nvidia-nvjitlink-cu12=12.3.101 nvidia-nvtx-cu12=12.1.105 onnx=1.15.0 packaging=23.2 paho-mqtt=1.6.1 pandas=2.1.4 pathtools=0.1.2 pillow=10.2.0 preshed=3.0.9 prettytable=3.9.0 promise=2.3 protobuf=3.20.3 psutil=5.9.7 py-machineid=0.4.6 pydantic=1.10.13 pyparsing=3.1.1 python-dateutil=2.8.2 python-rapidjson=1.14 pytz=2023.3.post1 PyYAML=6.0.1 redis=5.0.1 requests=2.31.0 s3transfer=0.10.0 scikit-learn=1.3.2 scipy=1.11.4 sentry-sdk=1.39.1 setproctitle=1.3.3 shortuuid=1.0.11 six=1.16.0 smart-open=6.3.0 smmap=5.0.1 sniffio=1.3.0 spacy=3.7.2 spacy-legacy=3.0.12 spacy-loggers=1.0.5 SQLAlchemy=2.0.25 srsly=2.4.8 starlette=0.25.0 sympy=1.12 thinc=8.2.2 threadpoolctl=3.2.0 torch=2.1.2 
torch-cluster=1.6.3 torch-scatter=2.1.2 torch-sparse=0.6.18 torch-spline-conv=1.2.2 torch_geometric=2.4.0 torchvision=0.16.2 tqdm=4.66.1 triton=2.1.0 tritonclient=2.41.0 typer=0.9.0 typing_extensions=4.9.0 tzdata=2023.4 tzlocal=5.2 urllib3=2.0.7 uvicorn=0.25.0 wandb=0.13.2 wasabi=1.1.2 wcwidth=0.2.12 weasel=0.3.4 websocket-client=1.7.0 wget=3.2 yarl=1.9.4 zope.event=5.0 zope.interface=6.1 @@ -57,7 +54,7 @@ train_args: The algorithm tested are `FedAvg`, `FedProx` and `FedOpt` ## Datasets -The experiments utilize datasets similar to those in the original Magic project. To change datasets, edit the `fedml_config.yaml` file: +The experiments utilize datasets similar to those in the original Continuum project. To change datasets, edit the `fedml_config.yaml` file: ``` data_args: dataset: "wget" diff --git a/a.yaml b/a.yaml deleted file mode 100644 index 9f8ce6e7fa889de842853b12c3a723d56ab60899..0000000000000000000000000000000000000000 --- a/a.yaml +++ /dev/null @@ -1,83 +0,0 @@ -common_args: - training_type: "cross_silo" - scenario: "horizontal" - using_mlops: false - config_version: release - name: "exp" - project: "runs/train" - exist_ok: false - random_seed: 0 - - common_args: - training_type: "simulation" - random_seed: 0 -comm_args: - backend: "MPI" - is_mobile: 0 - - -data_args: - dataset: "clintox" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition - partition_method: "hetero" - partition_alpha: 0.5 - -model_args: - model: "graphsage" - hidden_size: 32 - node_embedding_dim: 32 - graph_embedding_dim: 64 - readout_hidden_dim: 64 - alpha: 0.2 - num_heads: 2 - dropout: 0.3 - normalize_features: False - normalize_adjacency: False - sparse_adjacency: False - model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically - global_model_file_path: "./model_file_cache/global_model.pt" - -environment_args: - bootstrap: config/bootstrap.sh - -train_args: - federated_optimizer: "FedAvg" - client_id_list: - client_num_in_total: 3 - client_num_per_round: 3 - comm_round: 100 - epochs: 5 - batch_size: 64 - client_optimizer: sgd - learning_rate: 0.03 - weight_decay: 0.001 - metric: "prc-auc" - server_optimizer: sgd - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 - -validation_args: - frequency_of_the_test: 1 - -device_args: - worker_num: 3 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp - -comm_args: - backend: "MQTT_S3" - mqtt_config_path: config/mqtt_config.yaml - s3_config_path: config/s3_config.yaml - - -tracking_args: - # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ - enable_wandb: false - wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 - wandb_project: fedml - wandb_name: fedml_torch_moleculenet diff --git a/a.yml b/a.yml deleted file mode 100644 index adfbd8733fcee48a46e87a24cc48cbb1a0241f87..0000000000000000000000000000000000000000 --- a/a.yml +++ /dev/null @@ -1,68 +0,0 @@ -common_args: - training_type: "simulation" - random_seed: 0 - -data_args: - dataset: "clintox" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition - partition_method: "hetero" - partition_alpha: 0.5 - -model_args: - model: "graphsage" - hidden_size: 32 - node_embedding_dim: 32 - graph_embedding_dim: 64 - readout_hidden_dim: 64 - alpha: 0.2 - num_heads: 2 - dropout: 0.3 - normalize_features: False - normalize_adjacency: False - sparse_adjacency: False - 
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically - global_model_file_path: "./model_file_cache/global_model.pt" - -environment_args: - bootstrap: config/bootstrap.sh - -train_args: - federated_optimizer: "FedAvg" - client_id_list: - client_num_in_total: 3 - client_num_per_round: 3 - comm_round: 100 - epochs: 5 - batch_size: 64 - client_optimizer: sgd - learning_rate: 0.03 - weight_decay: 0.001 - metric: "prc-auc" - server_optimizer: sgd - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 - -validation_args: - frequency_of_the_test: 1 - -device_args: - worker_num: 2 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp - -comm_args: - backend: "MPI" - is_mobile: 0 - - -tracking_args: - # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ - enable_wandb: false - wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 - wandb_project: fedml - wandb_name: fedml_torch_moleculenet diff --git a/assets/.gitkeep b/assets/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/assets/archiFedHe.png b/assets/archiFedHe.png deleted file mode 100644 index 8495f1456bb9251d2022b52c605efaddbebd447b..0000000000000000000000000000000000000000 Binary files a/assets/archiFedHe.png and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-SC2.pt b/checkpoints - Copie/checkpoint-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..963a001f792613ee2fb2f5f22cb25b599a8dedbb Binary files /dev/null and b/checkpoints - Copie/checkpoint-SC2.pt differ diff --git a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..20073e1503d9f57b75c278bdb89c757d3aaccb30 Binary files /dev/null and b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt differ diff --git a/checkpoints - Copie/checkpoint-cadets-e3.pt b/checkpoints - Copie/checkpoint-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..ea7ee25689066513d46aea55e07f71df859ec8c7 Binary files /dev/null and b/checkpoints - Copie/checkpoint-cadets-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-clearscope-e3.pt b/checkpoints - Copie/checkpoint-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..5edaf9d73bdde92067b7e2f6f0129a4673feb945 Binary files /dev/null and b/checkpoints - Copie/checkpoint-clearscope-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-streamspot.pt b/checkpoints - Copie/checkpoint-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..ec1ed84e3dc16589af58039e87a34f21cc579cd5 Binary files /dev/null and b/checkpoints - Copie/checkpoint-streamspot.pt differ diff --git a/checkpoints - Copie/checkpoint-theia-e3.pt b/checkpoints - Copie/checkpoint-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..1513228c205f790835edb723dffe967eec960732 Binary files /dev/null and b/checkpoints - Copie/checkpoint-theia-e3.pt differ diff --git a/checkpoints - Copie/checkpoint-trace-e3.pt b/checkpoints - Copie/checkpoint-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..755598d71f01bee477182c9fe454b4646ac7ffcd Binary files /dev/null and b/checkpoints - Copie/checkpoint-trace-e3.pt differ diff --git 
a/checkpoints - Copie/checkpoint-wget-long.pt b/checkpoints - Copie/checkpoint-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..f5c8f96c5be068033746e3b9271b983c438bee54 Binary files /dev/null and b/checkpoints - Copie/checkpoint-wget-long.pt differ diff --git a/checkpoints - Copie/checkpoint-wget.pt b/checkpoints - Copie/checkpoint-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..1d5f6c0071b7c0d0e136aa6112ade76c083984fc Binary files /dev/null and b/checkpoints - Copie/checkpoint-wget.pt differ diff --git a/checkpoints/checkpoint-SC2.pt b/checkpoints/checkpoint-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..aec48a89437ad5ef4c3199b0547a3efc9365115f Binary files /dev/null and b/checkpoints/checkpoint-SC2.pt differ diff --git a/checkpoints/checkpoint-Unicorn-Cadets.pt b/checkpoints/checkpoint-Unicorn-Cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..90d5dcfbb036da54972648625073287bad1b6986 Binary files /dev/null and b/checkpoints/checkpoint-Unicorn-Cadets.pt differ diff --git a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt new file mode 100644 index 0000000000000000000000000000000000000000..5f62ae9abed4280d950ebf431a30b2f883b386a2 Binary files /dev/null and b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt differ diff --git a/checkpoints/checkpoint-cadets-e3.pt b/checkpoints/checkpoint-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..157f6b404a8228c9da2483564ab892e0fd4f8a88 Binary files /dev/null and b/checkpoints/checkpoint-cadets-e3.pt differ diff --git a/checkpoints/checkpoint-clearscope-e3.pt b/checkpoints/checkpoint-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..0ebdb975ec2e7455daed67ab002f1b39a9984c1b Binary files /dev/null and b/checkpoints/checkpoint-clearscope-e3.pt differ diff --git a/checkpoints/checkpoint-streamspot.pt b/checkpoints/checkpoint-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..0065a7b8b52665153af952617ab45cde61fe14eb Binary files /dev/null and b/checkpoints/checkpoint-streamspot.pt differ diff --git a/checkpoints/checkpoint-theia-e3.pt b/checkpoints/checkpoint-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..1f0fcc66e18b3e6d534f99e090e5387807451a64 Binary files /dev/null and b/checkpoints/checkpoint-theia-e3.pt differ diff --git a/checkpoints/checkpoint-trace-e3.pt b/checkpoints/checkpoint-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..f050065e56c009605562ca6e980652b6c2a33938 Binary files /dev/null and b/checkpoints/checkpoint-trace-e3.pt differ diff --git a/checkpoints/checkpoint-wget-long.pt b/checkpoints/checkpoint-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..040897b99810c72f6bb6d2314aab468b3972de3e Binary files /dev/null and b/checkpoints/checkpoint-wget-long.pt differ diff --git a/checkpoints/checkpoint-wget.pt b/checkpoints/checkpoint-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..8404279998dc213fba6b6f626d55c37ccb7c8ca9 Binary files /dev/null and b/checkpoints/checkpoint-wget.pt differ diff --git a/distance_save_cadets.pkl b/distance_save_cadets.pkl deleted file mode 100644 index 1f2f428d6e2e6ffb8afbfd0d0669c46cab8718c6..0000000000000000000000000000000000000000 Binary files a/distance_save_cadets.pkl and /dev/null differ diff --git a/eval_result/.gitkeep 
b/eval_result/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/eval_result/distance_save_cadets-e3 - Copie.pkl b/eval_result/distance_save_cadets-e3 - Copie.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7f737190e44abb34ebea3e6763c64b9f1ae3a89a Binary files /dev/null and b/eval_result/distance_save_cadets-e3 - Copie.pkl differ diff --git a/eval_result/distance_save_cadets-e3.pkl b/eval_result/distance_save_cadets-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..10c18f8c50bfce89a1afd0811cc24ce8e0595786 Binary files /dev/null and b/eval_result/distance_save_cadets-e3.pkl differ diff --git a/eval_result/distance_save_cadets.pkl b/eval_result/distance_save_cadets.pkl deleted file mode 100644 index 2bb4de8e00a64a22e53b1a5187821e6e53dfffcc..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets.pkl and /dev/null differ diff --git a/eval_result/distance_save_theia-e3.pkl b/eval_result/distance_save_theia-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c5d31bab1acc6e943c746a154eec3a9370ccb644 Binary files /dev/null and b/eval_result/distance_save_theia-e3.pkl differ diff --git a/eval_result/distance_save_trace-e3 - Copie.pkl b/eval_result/distance_save_trace-e3 - Copie.pkl new file mode 100644 index 0000000000000000000000000000000000000000..709b4f83bc0d8cd4d8be175fed1c54d366c9de16 Binary files /dev/null and b/eval_result/distance_save_trace-e3 - Copie.pkl differ diff --git a/eval_result/distance_save_trace-e3.pkl b/eval_result/distance_save_trace-e3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..709b4f83bc0d8cd4d8be175fed1c54d366c9de16 Binary files /dev/null and b/eval_result/distance_save_trace-e3.pkl differ diff --git a/fedml_config.yaml b/fedml_config.yaml index 66714a20c7a2b471676dac8be0801b3ae132e5cf..40cd616a925f3514dda97dd8ba1b20aaead81744 100644 --- a/fedml_config.yaml +++ b/fedml_config.yaml @@ -1,49 +1,53 @@ common_args: - training_type: "simulation" + training_type: "cross_silo" + scenario: "horizontal" + using_mlops: false + config_version: release + name: "exp" + project: "runs/train" + exist_ok: false random_seed: 0 data_args: - dataset: "wget" - data_cache_dir: ~/fedgraphnn_data/ - part_file: ~/fedgraphnn_data/partition + dataset: "trace-e3" model_args: model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically global_model_file_path: "./model_file_cache/global_model.pt" -environment_args: - bootstrap: config/bootstrap.sh train_args: federated_optimizer: "FedAvg" client_id_list: - client_num_in_total: 4 - client_num_per_round: 4 - comm_round: 100 - lr: 0.001 - server_lr: 0.001 - wd: 0.001 - ci: 0 - server_momentum: 0.9 + client_num_in_total: 2 + client_num_per_round: 2 + comm_round: 1 + snapshot: 1 validation_args: frequency_of_the_test: 1 device_args: - worker_num: 4 - using_gpu: false - gpu_mapping_file: config/gpu_mapping.yaml - gpu_mapping_key: mapping_fedgraphnn_sp + worker_num: 2 + using_gpu: true + gpu_mapping_file: gpu_mapping.yaml + gpu_mapping_key: mapping_config comm_args: - backend: "MPI" - is_mobile: 0 - + backend: "MQTT_S3" + mqtt_config_path: config/mqtt_config.yaml tracking_args: # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ enable_wandb: false wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 wandb_project: 
fedml - wandb_name: fedml_torch_moleculenet + wandb_name: fedml_torch + +# fhe_args: +# # enable_fhe: true + # scheme: ckks +# batch_size: 8192 +# scaling_factor: 52 +# file_loc: "resources/cryptoparams/" \ No newline at end of file diff --git a/gpu_mapping.yaml b/gpu_mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90b13bfe6abfa84b56845ac220a5c67ced781cbe --- /dev/null +++ b/gpu_mapping.yaml @@ -0,0 +1,2 @@ +mapping_config: + host1: [3] \ No newline at end of file diff --git a/main.py b/main.py index a1a716e17f69f73d33fa62ef5b3da1e85b90b715..7d7430d7fb88d4259d8096266633bb5dbd94bc8f 100644 --- a/main.py +++ b/main.py @@ -1,20 +1,18 @@ import logging import fedml -from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split +from utils.dataloader import load_partition_data, load_data, load_metadata, darpa_split from fedml import FedMLRunner from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from model.autoencoder import build_model +from model.model import STGNN_AutoEncoder from utils.config import build_args from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from trainer.single_trainer import train_single -from utils.loaddata import load_batch_level_dataset, load_metadata -def generate_dataset(name, number): +def generate_dataset(name, number, nsnapshot): ( train_data_num, val_data_num, @@ -26,7 +24,7 @@ def generate_dataset(name, number): train_data_local_dict, val_data_local_dict, test_data_local_dict, - ) = load_partition_data(None, number, name) + ) = load_partition_data(number, name, nsnapshot) dataset = [ train_data_num, test_data_num, @@ -38,9 +36,9 @@ def generate_dataset(name, number): len(train_data_global), ] - if (name == "wget" or name == "streamspot"): + if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - return dataset, load_batch_level_dataset(name) + return dataset, load_data(name) else: return dataset, load_metadata(name) @@ -49,39 +47,47 @@ if __name__ == "__main__": # init FedML framework args = fedml.init() # init device + device = fedml.device.get_device(args) - name = args.dataset + dataset_name = args.dataset number = args.client_num_in_total - - dataset, metadata = generate_dataset(name, number) + nsnapshot = args.snapshot + dataset, metadata = generate_dataset(dataset_name, number, nsnapshot) main_args = build_args() - if (name == "wget"): - main_args["num_hidden"] = 256 - main_args["max_epoch"] = 2 - main_args["num_layers"] = 4 + if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + + main_args.max_epoch = 6 + out_dim = 64 + if (dataset_name == 'SC2'): + gnn_layer = 3 + else: + gnn_layer = 5 n_node_feat = metadata['n_feat'] n_edge_feat = metadata['e_feat'] - main_args["n_dim"] = n_node_feat - main_args["e_dim"] = n_edge_feat - elif (name == "streamspot"): - main_args["num_hidden"] = 256 - main_args["max_epoch"] = 5 - main_args["num_layers"] = 4 - n_node_feat = metadata['n_feat'] - n_edge_feat = metadata['e_feat'] - main_args["n_dim"] = n_node_feat - main_args["e_dim"] = n_edge_feat - else: - main_args["num_hidden"] = 64 - main_args["max_epoch"] = 50 - main_args["num_layers"] = 3 - main_args["n_dim"] = metadata["node_feature_dim"] - main_args["e_dim"] = metadata["edge_feature_dim"] - - model = build_model(main_args) + use_all_hidden = True + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + 
else: + use_all_hidden = False + n_node_feat = metadata['node_feature_dim'] + n_edge_feat = metadata['edge_feature_dim'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + main_args.max_epoch = 50 + out_dim = 64 + + if (dataset_name == 'cadets-e3'): + gnn_layer = 4 + else: + gnn_layer = 3 + + + + model = STGNN_AutoEncoder(main_args.n_dim, main_args.e_dim, out_dim, out_dim, gnn_layer, 4, device, nsnapshot, 'prelu', 0.1, main_args.negative_slope, True, 'BatchNorm', main_args.pooling, alpha_l=main_args.alpha_l, use_all_hidden=use_all_hidden).to(device) # Move model to GPU #train_single(main_args, model, data) - trainer = MagicTrainer(model, args, name) - aggregator = MagicWgetAggregator(model, args, name) + trainer = MagicTrainer(model, args, dataset_name) + aggregator = MagicWgetAggregator(model, args, dataset_name) fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator) fedml_runner.run() # start training diff --git a/model/__pycache__/autoencoder.cpython-311.pyc b/model/__pycache__/autoencoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3abc72625245a57fa06470e3659c594a0c033046 Binary files /dev/null and b/model/__pycache__/autoencoder.cpython-311.pyc differ diff --git a/model/__pycache__/eval.cpython-310.pyc b/model/__pycache__/eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc58b1214be9d335558c42907cf1feb18b81969 Binary files /dev/null and b/model/__pycache__/eval.cpython-310.pyc differ diff --git a/model/__pycache__/eval.cpython-311.pyc b/model/__pycache__/eval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faade8397a5b3c546b35f34aef088d8992486510 Binary files /dev/null and b/model/__pycache__/eval.cpython-311.pyc differ diff --git a/model/__pycache__/gat.cpython-310.pyc b/model/__pycache__/gat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54fa4abdd4680faa41fbf0e19697fab2f000402b Binary files /dev/null and b/model/__pycache__/gat.cpython-310.pyc differ diff --git a/model/__pycache__/gat.cpython-311.pyc b/model/__pycache__/gat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94e6046f4f22c6a2f271c2ca3bded5d8f2417cfa Binary files /dev/null and b/model/__pycache__/gat.cpython-311.pyc differ diff --git a/model/__pycache__/loss_func.cpython-310.pyc b/model/__pycache__/loss_func.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ac15810418709637606b15c5c43b4b96aec691 Binary files /dev/null and b/model/__pycache__/loss_func.cpython-310.pyc differ diff --git a/model/__pycache__/loss_func.cpython-311.pyc b/model/__pycache__/loss_func.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b3c9d58012a8600cda16f7d39ee5c5891fa862 Binary files /dev/null and b/model/__pycache__/loss_func.cpython-311.pyc differ diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3984a3444fb13f56ea0c2629e2b0883298bf8f4d Binary files /dev/null and b/model/__pycache__/model.cpython-310.pyc differ diff --git a/model/__pycache__/model.cpython-311.pyc b/model/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a27f13ee6839a64257d3125b777544a866f329 Binary files /dev/null and b/model/__pycache__/model.cpython-311.pyc differ diff --git a/model/__pycache__/rnn.cpython-310.pyc b/model/__pycache__/rnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d02d9b442912cda476cacbc18673eebe9e000da7 Binary files /dev/null and b/model/__pycache__/rnn.cpython-310.pyc differ diff --git a/model/__pycache__/rnn.cpython-311.pyc b/model/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d71283ae5cbfc01ec9eb9f1448d8e8b292d1f08 Binary files /dev/null and b/model/__pycache__/rnn.cpython-311.pyc differ diff --git a/model/__pycache__/test.cpython-310.pyc b/model/__pycache__/test.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d85be9aa8853ef325090a1b1cd0f366d3224792 Binary files /dev/null and b/model/__pycache__/test.cpython-310.pyc differ diff --git a/model/__pycache__/test.cpython-311.pyc b/model/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7beba377fb059135e3135b292a910477636c7c Binary files /dev/null and b/model/__pycache__/test.cpython-311.pyc differ diff --git a/model/__pycache__/train.cpython-311.pyc b/model/__pycache__/train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b788a354802ea312364f91691db5b00e502922a6 Binary files /dev/null and b/model/__pycache__/train.cpython-311.pyc differ diff --git a/model/__pycache__/train_entity.cpython-310.pyc b/model/__pycache__/train_entity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbe4cebe4c9d3a09cdc4a8b78f976f98a23baef1 Binary files /dev/null and b/model/__pycache__/train_entity.cpython-310.pyc differ diff --git a/model/__pycache__/train_entity.cpython-311.pyc b/model/__pycache__/train_entity.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1b64f3a76e749403864eb4a4a10154a8003a84 Binary files /dev/null and b/model/__pycache__/train_entity.cpython-311.pyc differ diff --git a/model/__pycache__/train_graph.cpython-310.pyc b/model/__pycache__/train_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf931f350c1de50fb725f95c7ebc73f1c6facc66 Binary files /dev/null and b/model/__pycache__/train_graph.cpython-310.pyc differ diff --git a/model/__pycache__/train_graph.cpython-311.pyc b/model/__pycache__/train_graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98093beaf11c1ffeb71b9abc556e33503934038 Binary files /dev/null and b/model/__pycache__/train_graph.cpython-311.pyc differ diff --git a/model/eval.py b/model/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7446db31e19a2e144eb061e0526cb44e23798935 --- /dev/null +++ b/model/eval.py @@ -0,0 +1,270 @@ +import os +import random +import time +import pickle as pkl +import torch +import numpy as np +from sklearn.metrics import roc_auc_score, precision_recall_curve +from sklearn.neighbors import NearestNeighbors +from utils.utils import set_random_seed +from utils.dataloader import load_graph, load_data +import pickle as 
pkl + +def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): + print('Start Evaluation') + + model.eval() + x_list = [] + y_list = [] + data = load_data(dataset) + full = data['full_index'] + labels = data['labels'] + with torch.no_grad(): + for i in full: + #break + g = load_graph(i, dataset, device) + label = labels[i] + out = model.embed(g) + if dataset != 'wget': + out = pooler(g[-1], out).cpu().numpy() + else: + out = pooler(g[-1], out, [1]).cpu().numpy() + y_list.append(label) + x_list.append(out) + + #pkl.dump(x_list,open('xlist.pkl','wb') ) + #pkl.dump(y_list,open('ylist.pkl','wb') ) + #x_list = pkl.load(open('xlist.pkl','rb')) + #y_list = pkl.load(open('ylist.pkl','rb')) + x = np.concatenate(x_list, axis=0) + y = np.array(y_list) + if 'knn' in method: + test_auc, test_std = evaluate_batch_level_using_knn(-1, dataset, x, y) + else: + raise NotImplementedError + return test_auc, test_std + + +def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): + x, y = embeddings, labels + if dataset == 'streamspot': + train_count = 400 + elif (dataset == 'Unicorn-Cadets' or dataset == 'wget-long'): + train_count = 70 + elif (dataset == 'wget' or dataset == 'SC2'): + train_count = 100 + else: + train_count = 30 + + if (dataset =='SC2'): + n_neighbors = min(int(train_count * 0.02), 10) + else: + n_neighbors = 100 + + benign_idx = np.where(y == 0)[0] + attack_idx = np.where(y == 1)[0] + if repeat != -1: + prec_list = [] + rec_list = [] + f1_list = [] + tp_list = [] + fp_list = [] + tn_list = [] + fn_list = [] + auc_list = [] + for s in range(repeat): + set_random_seed(s) + np.random.shuffle(benign_idx) + np.random.shuffle(attack_idx) + x_train = x[benign_idx[:train_count]] + x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + max_f1_idx = np.argmax(f1) + best_thres = threshold[max_f1_idx] + prec_list.append(prec[max_f1_idx]) + rec_list.append(rec[max_f1_idx]) + f1_list.append(f1[max_f1_idx]) + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + tp_list.append(tp) + fp_list.append(fp) + fn_list.append(fn) + tn_list.append(tn) + auc_list.append(auc) + + print('AUC: {}+{}'.format(np.mean(auc_list), np.std(auc_list))) + print('F1: {}+{}'.format(np.mean(f1_list), np.std(f1_list))) + print('PRECISION: {}+{}'.format(np.mean(prec_list), np.std(prec_list))) + print('RECALL: {}+{}'.format(np.mean(rec_list), np.std(rec_list))) + print('TN: {}+{}'.format(np.mean(tn_list), np.std(tn_list))) + print('FN: {}+{}'.format(np.mean(fn_list), np.std(fn_list))) + 
print('TP: {}+{}'.format(np.mean(tp_list), np.std(tp_list))) + print('FP: {}+{}'.format(np.mean(fp_list), np.std(fp_list))) + return np.mean(auc_list), np.std(auc_list) + else: + set_random_seed(2022) + np.random.shuffle(benign_idx) + np.random.shuffle(attack_idx) + x_train = x[benign_idx[:train_count]] + #x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) + y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + f1_max = 0 + for n_neighbors in range(1, train_count): + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + #mean_distance = 0.1 + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + + if (f1[best_idx]> f1_max): + f1_max = f1[best_idx] + auc_max = auc + prec_max = prec[best_idx] + rec_max = rec[best_idx] + tn_max = tn + fn_max = fn + tp_max = tp + fp_max = fp + best_n = n_neighbors + + print('AUC: {}'.format(auc_max)) + print('F1: {}'.format(f1_max)) + print('PRECISION: {}'.format(prec_max)) + print('RECALL: {}'.format(rec_max)) + print('TN: {}'.format(tn_max)) + print('FN: {}'.format(fn_max)) + print('TP: {}'.format(tp_max)) + print('FP: {}'.format(fp_max)) + print(best_n) + return auc, 0.0 + + +def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): + x_train_mean = x_train.mean(axis=0) + x_train_std = x_train.std(axis=0) + x_train = (x_train - x_train_mean) / x_train_std + x_test = (x_test - x_train_mean) / x_train_std + + if dataset == 'cadets-e3': + n_neighbors = 200 + else: + n_neighbors = 10 + + nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1) + nbrs.fit(x_train) + + save_dict_path = './eval_result/distance_save_{}.pkl'.format(dataset) + if not os.path.exists(save_dict_path): + idx = list(range(x_train.shape[0])) + random.shuffle(idx) + distances, _ = nbrs.kneighbors(x_train[idx][:min(50000, x_train.shape[0])], n_neighbors=n_neighbors) + del x_train + mean_distance = distances.mean() + del distances + distances, _ = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + save_dict = [mean_distance, distances.mean(axis=1)] + distances = distances.mean(axis=1) + with open(save_dict_path, 'wb') as f: + pkl.dump(save_dict, f) + else: + with open(save_dict_path, 'rb') as f: + mean_distance, distances = pkl.load(f) + score = distances / mean_distance + del distances + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + + for i in range(len(f1)): + # To repeat peak performance + if dataset == 'trace' and 
rec[i] < 0.99979: + best_idx = i - 1 + break + if dataset == 'theia' and rec[i] < 0.99996: + best_idx = i - 1 + break + if dataset == 'cadets-e3' and rec[i] < 0.9976: + best_idx = i - 1 + break + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + print('AUC: {}'.format(auc)) + print('F1: {}'.format(f1[best_idx])) + print('PRECISION: {}'.format(prec[best_idx])) + print('RECALL: {}'.format(rec[best_idx])) + print('TN: {}'.format(tn)) + print('FN: {}'.format(fn)) + print('TP: {}'.format(tp)) + print('FP: {}'.format(fp)) + return auc, 0.0, None, None \ No newline at end of file diff --git a/model/gat.py b/model/gat.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2d2daeca4468f356274049f7f4454496a8ac06 --- /dev/null +++ b/model/gat.py @@ -0,0 +1,232 @@ +import torch +import torch.nn as nn +from dgl.ops import edge_softmax +import dgl.function as fn +from dgl.utils import expand_as_pair +from utils.utils import create_activation + + +class GAT(nn.Module): + def __init__(self, + n_dim, + e_dim, + hidden_dim, + out_dim, + n_layers, + n_heads, + n_heads_out, + activation, + feat_drop, + attn_drop, + negative_slope, + residual, + norm, + concat_out=False, + encoding=False + ): + super(GAT, self).__init__() + self.out_dim = out_dim + self.n_heads = n_heads + self.n_layers = n_layers + self.gats = nn.ModuleList() + self.concat_out = concat_out + + last_activation = create_activation(activation) if encoding else None + last_residual = (encoding and residual) + last_norm = norm if encoding else None + if self.n_layers == 1: + self.gats.append(GATConv( + n_dim, e_dim, out_dim, n_heads_out, feat_drop, attn_drop, negative_slope, + last_residual, norm=last_norm, concat_out=self.concat_out + )) + else: + self.gats.append(GATConv( + n_dim, e_dim, hidden_dim, n_heads, feat_drop, attn_drop, negative_slope, + residual, create_activation(activation), + norm=norm, concat_out=self.concat_out + )) + for _ in range(1, self.n_layers - 1): + self.gats.append(GATConv( + hidden_dim * self.n_heads, e_dim, hidden_dim, n_heads, + feat_drop, attn_drop, negative_slope, + residual, create_activation(activation), + norm=norm, concat_out=self.concat_out + )) + self.gats.append(GATConv( + hidden_dim * self.n_heads, e_dim, out_dim, n_heads_out, + feat_drop, attn_drop, negative_slope, + last_residual, last_activation, norm=last_norm, concat_out=self.concat_out + )) + self.head = nn.Identity() + + def forward(self, g, input_feature, return_hidden=False): + h = input_feature + hidden_list = [] + for layer in range(self.n_layers): + h = self.gats[layer](g, h) + hidden_list.append(h) + if return_hidden: + return self.head(h), hidden_list + else: + return self.head(h) + + def reset_classifier(self, num_classes): + self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) + + +class GATConv(nn.Module): + def __init__(self, + in_dim, + e_dim, + out_dim, + n_heads, + feat_drop=0.0, + attn_drop=0.0, + negative_slope=0.2, + residual=False, + activation=None, + allow_zero_in_degree=False, + bias=True, + norm=None, + concat_out=True): + super(GATConv, self).__init__() + self.n_heads = n_heads + self.src_feat, self.dst_feat = expand_as_pair(in_dim) + self.edge_feat = e_dim + self.out_feat = out_dim + self.allow_zero_in_degree 
= allow_zero_in_degree + self.concat_out = concat_out + + if isinstance(in_dim, tuple): + self.fc_node_embedding = nn.Linear( + self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc_src = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc_dst = nn.Linear(self.dst_feat, self.out_feat * self.n_heads, bias=False) + else: + self.fc_node_embedding = nn.Linear( + self.src_feat, self.out_feat * self.n_heads, bias=False) + self.fc = nn.Linear(self.src_feat, self.out_feat * self.n_heads, bias=False) + self.edge_fc = nn.Linear(self.edge_feat, self.out_feat * self.n_heads, bias=False) + self.attn_h = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.attn_e = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.attn_t = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if bias: + self.bias = nn.Parameter(torch.FloatTensor(size=(1, self.n_heads, self.out_feat))) + else: + self.register_buffer('bias', None) + if residual: + if self.dst_feat != self.n_heads * self.out_feat: + self.res_fc = nn.Linear( + self.dst_feat, self.n_heads * self.out_feat, bias=False) + else: + self.res_fc = nn.Identity() + else: + self.register_buffer('res_fc', None) + self.reset_parameters() + self.activation = activation + self.norm = norm + if norm is not None: + self.norm = norm(self.n_heads * self.out_feat) + + def reset_parameters(self): + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.edge_fc.weight, gain=gain) + if hasattr(self, 'fc'): + nn.init.xavier_normal_(self.fc.weight, gain=gain) + else: + nn.init.xavier_normal_(self.fc_src.weight, gain=gain) + nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + nn.init.xavier_normal_(self.attn_h, gain=gain) + nn.init.xavier_normal_(self.attn_e, gain=gain) + nn.init.xavier_normal_(self.attn_t, gain=gain) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + + def set_allow_zero_in_degree(self, set_value): + self.allow_zero_in_degree = set_value + + def forward(self, graph, feat, get_attention=False): + edge_feature = graph.edata['attr'] + with graph.local_scope(): + if isinstance(feat, tuple): + src_prefix_shape = feat[0].shape[:-1] + dst_prefix_shape = feat[1].shape[:-1] + h_src = self.feat_drop(feat[0]) + h_dst = self.feat_drop(feat[1]) + if not hasattr(self, 'fc_src'): + feat_src = self.fc(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + feat_dst = self.fc(h_dst).view( + *dst_prefix_shape, self.n_heads, self.out_feat) + else: + feat_src = self.fc_src(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + feat_dst = self.fc_dst(h_dst).view( + *dst_prefix_shape, self.n_heads, self.out_feat) + else: + src_prefix_shape = dst_prefix_shape = feat.shape[:-1] + h_src = h_dst = self.feat_drop(feat) + feat_src = feat_dst = self.fc(h_src).view( + *src_prefix_shape, self.n_heads, self.out_feat) + if graph.is_block: + feat_dst = feat_src[:graph.number_of_dst_nodes()] + h_dst = h_dst[:graph.number_of_dst_nodes()] + dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] + edge_prefix_shape = edge_feature.shape[:-1] + eh = (feat_src * self.attn_h).sum(-1).unsqueeze(-1) + et = (feat_dst * self.attn_t).sum(-1).unsqueeze(-1) + + graph.srcdata.update({'hs': feat_src, 'eh': eh}) + 
graph.dstdata.update({'et': et}) + + feat_edge = self.edge_fc(edge_feature).view( + *edge_prefix_shape, self.n_heads, self.out_feat) + ee = (feat_edge * self.attn_e).sum(-1).unsqueeze(-1) + + graph.edata.update({'ee': ee}) + graph.apply_edges(fn.u_add_e('eh', 'ee', 'ee')) + graph.apply_edges(fn.e_add_v('ee', 'et', 'e')) + """ + graph.apply_edges(fn.u_add_v('eh', 'et', 'e')) + """ + e = self.leaky_relu(graph.edata.pop('e')) + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + # message passing + + graph.update_all(fn.u_mul_e('hs', 'a', 'm'), + fn.sum('m', 'hs')) + + rst = graph.dstdata['hs'].view(-1, self.n_heads, self.out_feat) + + if self.bias is not None: + rst = rst + self.bias.view( + *((1,) * len(dst_prefix_shape)), self.n_heads, self.out_feat) + + # residual + + if self.res_fc is not None: + # Use -1 rather than self._num_heads to handle broadcasting + resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self.out_feat) + rst = rst + resval + + if self.concat_out: + rst = rst.flatten(1) + else: + rst = torch.mean(rst, dim=1) + + if self.norm is not None: + rst = self.norm(rst) + + # activation + if self.activation: + rst = self.activation(rst) + + if get_attention: + return rst, graph.edata['a'] + else: + return rst diff --git a/model/loss_func.py b/model/loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..a2eff3e79376bf6ec4157b3c56f3a6ebd50a3127 --- /dev/null +++ b/model/loss_func.py @@ -0,0 +1,9 @@ +import torch.nn.functional as F + + +def sce_loss(x, y, alpha=3): + x = F.normalize(x, p=2, dim=-1) + y = F.normalize(y, p=2, dim=-1) + loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) + loss = loss.mean() + return loss diff --git a/model/model - Copie.py b/model/model - Copie.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd76a77dea36485ce923267908cca624db6ae65 --- /dev/null +++ b/model/model - Copie.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.poolers import Pooling +from dgl.nn import EdgeGATConv, GlobalAttentionPooling +from torch.nn import GRUCell +import dgl + + + + + + +class GNN_DTDG(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, mlp_layers, number_snapshot): + super(GNN_DTDG, self).__init__() + self.encoder = GNN_RNN(n_dim, e_dim, hidden_dim, out_dim ,n_layers, n_heads, device, number_snapshot) + self.decoder = GNN_RNN(out_dim, e_dim, hidden_dim, n_dim ,n_layers, n_heads, device, number_snapshot) + self.number_snapshot = number_snapshot + self.classifier_layers = nn.ModuleList([ + ]) + + for _ in range(mlp_layers - 1): + self.classifier_layers.extend([ + nn.Linear(out_dim, out_dim).to(device), + nn.ReLU(), + ]) + self.classifier_layers.extend([ + nn.Linear(out_dim, 1).to(device), + nn.Sigmoid() + ]) + + self.pooling_gate_nn = nn.Linear(out_dim , 1) + self.pooling = GlobalAttentionPooling(self.pooling_gate_nn) + self.pooler = Pooling("mean") + self.encoder_to_decoder = nn.Linear( out_dim, out_dim, bias=False) + + def forward(self, g): + encoded = self.encoder(g) + new_g = [] + i= 0 + for G in g: + g_encoded = G.clone() + g_encoded.ndata["attr"] = self.encoder_to_decoder(encoded[i]) + new_g.append(g_encoded) + i+=1 + + + + decoded = self.decoder(new_g) + return decoded[-1] + # x = self.pooler(G, embeddings, [1])[0] + # h_g = x.clone() + # for layer in self.classifier_layers: + # x = layer(x) + + def embed(self, g): + return self.encoder(g)[-1] + + + +class GNN_RNN(nn.Module): + def __init__(self, n_dim, e_dim, 
hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot): + super(GNN_RNN, self).__init__() + self.device = device + self.gnn_layers = nn.ModuleList([EdgeGATConv(in_feats=n_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device)]) + + self.out_dim = out_dim + + for _ in range(n_layers-1): + self.gnn_layers.append( + EdgeGATConv(in_feats=out_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device) + ) + + self.rnn_layers = nn.ModuleList([]) + + for _ in range(number_snapshot): + self.rnn_layers.append( + GRUCell(out_dim, out_dim, device = device) + ) + + self.classifier_layers = nn.ModuleList([ + ]) + + + def forward(self, g): + i = 0 + H_s = [] + for G in g: + + with G.to(self.device).local_scope(): + x = G.ndata["attr"].float() + e = G.edata["attr"].float() + for layer in self.gnn_layers: + r = layer(G, x, e) + x = torch.mean(r,dim=1).to(self.device) + del r + + #if ( i == 0): + # H = self.rnn_layers[i](x, x) + # else: + # H = self.rnn_layers[i](x, H) + + H = x + H_s.append(H) + embeddings = H.clone() + i+=1 + #x = self.pooling(g[0], x)[0] + + + return H_s \ No newline at end of file diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0389eeae588c10742143ba62a97e9b8f5639eb47 --- /dev/null +++ b/model/model.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +from utils.poolers import Pooling +from .loss_func import sce_loss +from .gat import GAT +from .rnn import RNN_Cells +from utils.utils import create_norm +from functools import partial + + +class STGNN_AutoEncoder(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot, activation, feat_drop, negative_slope, residual, norm, pooling, loss_fn="sce", alpha_l=2, use_all_hidden = True): + super(STGNN_AutoEncoder, self).__init__() + + #Initialize the encoder and decoder structure + self.encoder = STGNN(n_dim, e_dim, out_dim, out_dim, n_layers, n_heads, n_heads, number_snapshot, activation, feat_drop, negative_slope, residual, norm, True, use_all_hidden, device) + self.decoder = STGNN(out_dim, e_dim, out_dim, n_dim, 1, n_heads, 1, number_snapshot, activation, feat_drop, negative_slope, residual, norm, False, False, device) + + + # Linear layer for mapping encoder output to decoder input + if (use_all_hidden): + self.encoder_to_decoder = nn.Linear(n_layers * out_dim, out_dim, bias=False) + else: + self.encoder_to_decoder = nn.Linear(out_dim, out_dim, bias=False) + + # Additional components and parameters + self.n_layers = n_layers + self.pooler = Pooling(pooling) + self.number_snapshot = number_snapshot + self.use_all_hidden = use_all_hidden + self.device = device + self.criterion = self.setup_loss_fn(loss_fn, alpha_l) + + def setup_loss_fn(self, loss_fn, alpha_l): + if loss_fn == "sce": + criterion = partial(sce_loss, alpha=alpha_l) + + elif loss_fn == "ce": + criterion = nn.CrossEntropyLoss() + elif loss_fn == "mse": + criterion = nn.MSELoss() + elif loss_fn == "mae": + criterion = nn.L1Loss() + else: + raise NotImplementedError + return criterion + + + def forward(self, g): + + # Encode input graphs + node_features = [] + new_t = [] + for G in g: + new_g = G.clone() + node_features.append(new_g.ndata['attr'].float()) + new_g.edata['attr'] = new_g.edata['attr'].float() + new_t.append(new_g) + final_embedding = self.encoder(new_t, node_features) + encoding = [] + if (self.use_all_hidden): + for i in range(len(g)): + conca = 
[final_embedding[j][i] for j in range(len(final_embedding))] + encoding.append(torch.cat(conca,dim=1)) + else: + encoding = final_embedding[0] + + node_features = [] + for encoded in encoding: + encoded = self.encoder_to_decoder(encoded) + node_features.append(encoded) + + reconstructed = self.decoder(new_t, node_features) + recon = reconstructed[0][-1] + x_init = g[0].ndata['attr'].float() + loss = self.criterion(recon, x_init) + + return loss + + def embed(self, g): + node_features= [] + for G in g: + node_features.append(G.ndata['attr'].float()) + + return self.encoder.embed(g, node_features) + + +class STGNN(nn.Module): + + def __init__(self, input_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, n_heads_out, n_snapshot, activation, feat_drop, negative_slope, residual, norm, encoding, use_all_hidden, device): + super(STGNN, self).__init__() + + if encoding: + out = out_dim // n_heads + hidden = out_dim // n_heads + else: + hidden = hidden_dim + out = out_dim + + self.gnn = GAT( + n_dim=input_dim, + e_dim=e_dim, + hidden_dim=hidden, + out_dim=out, + n_layers=n_layers, + n_heads=n_heads, + n_heads_out=n_heads_out, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=encoding, + ) + self.rnn = RNN_Cells(out_dim, out_dim, n_snapshot, device) + self.use_all_hidden = use_all_hidden + + def forward(self, G, node_features): + + embeddings = [] + for i in range(len(G)): + g = G[i] + if (self.use_all_hidden): + node_embedding, all_hidden = self.gnn(g, node_features[i], return_hidden = self.use_all_hidden) + embeddings.append(all_hidden) + n_iter = len(all_hidden) + + else: + embeddings.append(self.gnn(g, node_features[i], return_hidden = self.use_all_hidden)) + n_iter = 1 + + result = [] + for j in range(n_iter): + encoding = [] + + for embedding in embeddings : + if (self.use_all_hidden): + encoding.append(embedding[j]) + else: + encoding.append(embedding) + + result.append(self.rnn(encoding)) + + return result + + + def embed(self, G, node_features): + embeddings = [] + for i in range(len(G)): + g = G[i].clone() + g.edata['attr'] = g.edata['attr'].float() + embedding = self.gnn(g, node_features[i], return_hidden = False) + embeddings.append(embedding) + + return self.rnn(embeddings)[-1] \ No newline at end of file diff --git a/model/rnn.py b/model/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..663634fbdfbba966894f74241faed35876657f5c --- /dev/null +++ b/model/rnn.py @@ -0,0 +1,23 @@ +from torch.nn import GRUCell +import torch.nn as nn + + +class RNN_Cells(nn.Module): + def __init__(self, input_dim, hidden_dim, n_cells, device) : + super(RNN_Cells, self).__init__() + self.cells = nn.ModuleList() + + for i in range(n_cells): + self.cells.append(GRUCell(input_dim, hidden_dim, device=device)) + + + def forward(self, inputs): + + results = [] + for i in range(len(self.cells)): + if (i == 0): + results.append(self.cells[i](inputs[i], inputs[i])) + else: + results.append(self.cells[i](inputs[i], results[i-1])) + + return results \ No newline at end of file diff --git a/model/train_entity.py b/model/train_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..58aa15140e6204fbc48d865036ea37eae251858a --- /dev/null +++ b/model/train_entity.py @@ -0,0 +1,29 @@ +from tqdm import tqdm +from utils.dataloader import load_entity_level_dataset +import torch + + + + +def entity_level_train(model, snapshot, optimizer, max_epoch, device, 
dataset_name, train_data): + + model = model.to(device) + model.train() + epoch_iter = tqdm(range(max_epoch)) + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in train_data: + g = load_entity_level_dataset(dataset_name, 'train', i, snapshot, device) + model.train() + loss = model(g) + loss /= len(train_data) + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward() + optimizer.step() + + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + + return model \ No newline at end of file diff --git a/model/train_graph.py b/model/train_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffbd9869fab466b92497c4815db704f9adc1727 --- /dev/null +++ b/model/train_graph.py @@ -0,0 +1,50 @@ +import numpy as np +from tqdm import tqdm +import torch +from utils.dataloader import load_graph +import matplotlib.pyplot as plt +from model.eval import batch_level_evaluation + +def batch_level_train(model, train_loader, optimizer, max_epoch, device, n_dim, e_dim, dataset_name, validation=True): + + epoch_iter = tqdm(range(max_epoch)) + model.to(device) # Move model to GPU + n_epoch = 0 + validation_f1 = [] + loss_global = [] + for epoch in epoch_iter: + model.train() + loss_list = [] + for iter, batch in enumerate(train_loader): + batch_g = [load_graph(int(idx), dataset_name, device) for idx in batch] # Move data to GPU + model.train() + g = batch_g[0] + loss = model(g) + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_list.append(loss.item()) + del batch_g, g + + n_epoch +=1 + + if (validation): + validation_f1.append(batch_level_evaluation(model, model.pooler, device, ['knn'], dataset_name, n_dim, e_dim)[0]) + + loss_global.append(np.mean(loss_list)) + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") + + if (validation): + plt.plot(list(range(n_epoch)), validation_f1, label='Graph 2', marker='o', linestyle='-') + plt.plot(list(range(n_epoch)), loss_global, label='Graph 2', marker='x', linestyle='--') + plt.xlabel('X-axis') + plt.ylabel('Y-axis') + plt.title('Two Graphs on the Same Plot') + + # Add a legend + plt.legend() + + # Display the plot + plt.show() + return model diff --git a/result/FedAvg-2client-2round-cadets.pt b/result/FedAvg-2client-2round-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..97537e00b83117c06200ca53c0e22cef67b81fe8 Binary files /dev/null and b/result/FedAvg-2client-2round-cadets.pt differ diff --git a/result/FedAvg-2client-cadets.pt b/result/FedAvg-2client-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..b9fb122af9072298b859056fa94be04932ef7a8b Binary files /dev/null and b/result/FedAvg-2client-cadets.pt differ diff --git a/result/FedAvg-4client-cadets.pt b/result/FedAvg-4client-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..4fb2cd9db49d52015b81ce437063abfa0c17d065 Binary files /dev/null and b/result/FedAvg-4client-cadets.pt differ diff --git a/result/FedAvg-SC2.pt b/result/FedAvg-SC2.pt new file mode 100644 index 0000000000000000000000000000000000000000..89411388841175b247e424a74463231d45fe1391 Binary files /dev/null and b/result/FedAvg-SC2.pt differ diff --git a/result/FedAvg-Unicorn-Cadets.pt b/result/FedAvg-Unicorn-Cadets.pt new file mode 100644 index 
0000000000000000000000000000000000000000..161ba7a56e45a3d1c33bed20b12bf8564b50a8c8 Binary files /dev/null and b/result/FedAvg-Unicorn-Cadets.pt differ diff --git a/result/FedAvg-cadets-e3.pt b/result/FedAvg-cadets-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..4afaa90ee421747fa19d39673f9b63b06eb9804f Binary files /dev/null and b/result/FedAvg-cadets-e3.pt differ diff --git a/result/FedAvg-cadets.pt b/result/FedAvg-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..011e3d240dd0a104c0f4bc54541cb31c8cab5fcb Binary files /dev/null and b/result/FedAvg-cadets.pt differ diff --git a/result/FedAvg-clearscope-e3.pt b/result/FedAvg-clearscope-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..136dc3c1eef647a32d2189a8896551aa1b5df162 Binary files /dev/null and b/result/FedAvg-clearscope-e3.pt differ diff --git a/result/FedAvg-streamspot.pt b/result/FedAvg-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..82308fe0b86721ef8790eff5354ec0e74388d70e Binary files /dev/null and b/result/FedAvg-streamspot.pt differ diff --git a/result/FedAvg-theia-e3.pt b/result/FedAvg-theia-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..5a995ad213ddd3883f7c8f000095fd6821410d2b Binary files /dev/null and b/result/FedAvg-theia-e3.pt differ diff --git a/result/FedAvg-theia.pt b/result/FedAvg-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..66b64fb4c8b42b811c28d7c37238c88d5122dbd2 Binary files /dev/null and b/result/FedAvg-theia.pt differ diff --git a/result/FedAvg-trace-e3.pt b/result/FedAvg-trace-e3.pt new file mode 100644 index 0000000000000000000000000000000000000000..4dca8e988169b150f8d1692ad751a5bb051e5609 Binary files /dev/null and b/result/FedAvg-trace-e3.pt differ diff --git a/result/FedAvg-wget-long.pt b/result/FedAvg-wget-long.pt new file mode 100644 index 0000000000000000000000000000000000000000..9e7504eb16446ab46f327bcc30f13d77b051f75e Binary files /dev/null and b/result/FedAvg-wget-long.pt differ diff --git a/result/FedAvg-wget.pt b/result/FedAvg-wget.pt new file mode 100644 index 0000000000000000000000000000000000000000..51c0ecdfba6589f6d63433368fb10c5810705765 Binary files /dev/null and b/result/FedAvg-wget.pt differ diff --git a/result/FedAvg_Streamspot-streamspot.pt b/result/FedAvg_Streamspot-streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead Binary files /dev/null and b/result/FedAvg_Streamspot-streamspot.pt differ diff --git a/result/FedOpt-cadets.pt b/result/FedOpt-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..291831ac85f89a3b16648d005a814a064f1c2cdf Binary files /dev/null and b/result/FedOpt-cadets.pt differ diff --git a/result/FedOpt-theia.pt b/result/FedOpt-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..383ebdb7544ff96e65dc408b857d5f61803b14e4 Binary files /dev/null and b/result/FedOpt-theia.pt differ diff --git a/result/FedOpt-trace.pt b/result/FedOpt-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..3e7faf31ca32212b8158216eaab0be77c2696ada Binary files /dev/null and b/result/FedOpt-trace.pt differ diff --git a/result/FedOpt_Streamspot.pt b/result/FedOpt_Streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..36015f47e0574d828dcc17178ebd5afbf8daaead Binary files /dev/null and b/result/FedOpt_Streamspot.pt differ diff --git 
a/result/FedProx-cadets.pt b/result/FedProx-cadets.pt new file mode 100644 index 0000000000000000000000000000000000000000..deaf520dbef791e5e8520370443acffe65e8825b Binary files /dev/null and b/result/FedProx-cadets.pt differ diff --git a/result/FedProx-theia.pt b/result/FedProx-theia.pt new file mode 100644 index 0000000000000000000000000000000000000000..100b9f885317fcc267d0fb44a53ca3928c908a4e Binary files /dev/null and b/result/FedProx-theia.pt differ diff --git a/result/FedProx-trace.pt b/result/FedProx-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..b34e70d94c6e557d0c2f92efd1d2082076393c38 Binary files /dev/null and b/result/FedProx-trace.pt differ diff --git a/result/FedProx_Streamspot.pt b/result/FedProx_Streamspot.pt new file mode 100644 index 0000000000000000000000000000000000000000..232f436cc87c609d59174ee35b8bebc1fcda303b Binary files /dev/null and b/result/FedProx_Streamspot.pt differ diff --git a/result/checkpoint-trace.pt b/result/checkpoint-trace.pt new file mode 100644 index 0000000000000000000000000000000000000000..880702e81259a697a9a794bb943927b2327b7420 Binary files /dev/null and b/result/checkpoint-trace.pt differ diff --git a/save_results/.gitkeep b/save_results/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/save_results/distance_save_trace_FedProx.pkl b/save_results/distance_save_trace_FedProx.pkl deleted file mode 100644 index e88de86f4bcb805e0475a41d4de3bbfde5d7353e..0000000000000000000000000000000000000000 Binary files a/save_results/distance_save_trace_FedProx.pkl and /dev/null differ diff --git a/test.py b/test.py deleted file mode 100644 index c1f6ebe728019b94bfcd9513907c28a9d39eaa81..0000000000000000000000000000000000000000 --- a/test.py +++ /dev/null @@ -1,15 +0,0 @@ -from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split - - -( - train_data_num, - val_data_num, - test_data_num, - train_data_global, - val_data_global, - test_data_global, - data_local_num_dict, - train_data_local_dict, - val_data_local_dict, - test_data_local_dict, - ) = load_partition_data(None, 3, "streamspot") diff --git a/train.py b/train.py deleted file mode 100644 index 968181a1cfaa6de944115c62b2b187376c734502..0000000000000000000000000000000000000000 --- a/train.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import random -import torch -import warnings -from tqdm import tqdm -from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph -from model.autoencoder import build_model -from torch.utils.data.sampler import SubsetRandomSampler -from dgl.dataloading import GraphDataLoader -import dgl -from model.train import batch_level_train -from utils.utils import set_random_seed, create_optimizer -from utils.config import build_args -warnings.filterwarnings('ignore') - - -def extract_dataloaders(entries, batch_size): - random.shuffle(entries) - train_idx = torch.arange(len(entries)) - train_sampler = SubsetRandomSampler(train_idx) - train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler) - return train_loader - - -def main(main_args): - device = "cpu" - dataset_name = "trace" - if dataset_name == 'streamspot': - main_args.num_hidden = 256 - main_args.max_epoch = 5 - main_args.num_layers = 4 - elif dataset_name == 'wget': - main_args.num_hidden = 256 - main_args.max_epoch = 2 - main_args.num_layers = 4 - else: - main_args["num_hidden"] = 64 - 
main_args["max_epoch"] = 50 - main_args["num_layers"] = 3 - set_random_seed(0) - - if dataset_name == 'streamspot' or dataset_name == 'wget': - if dataset_name == 'streamspot': - batch_size = 12 - else: - batch_size = 1 - dataset = load_batch_level_dataset(dataset_name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - graphs = dataset['dataset'] - train_index = dataset['train_index'] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - model = build_model(main_args) - model = model.to(device) - optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay) - model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)), - optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim) - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - else: - metadata = load_metadata(dataset_name) - main_args["n_dim"] = metadata['node_feature_dim'] - main_args["e_dim"] = metadata['edge_feature_dim'] - model = build_model(main_args) - model = model.to(device) - model.train() - optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"]) - epoch_iter = tqdm(range(main_args["max_epoch"])) - n_train = metadata['n_train'] - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in range(n_train): - g = load_entity_level_dataset(dataset_name, 'train', i).to(device) - model.train() - loss = model(g) - loss /= n_train - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward() - optimizer.step() - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name)) - - return - - -if __name__ == '__main__': - args = build_args() - main(args) diff --git a/trainer/.gitkeep b/trainer/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/trainer/__pycache__/magic_aggregator.cpython-310.pyc b/trainer/__pycache__/magic_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..173a82a3b45a20cb705440e95a4e6d09bf98930b Binary files /dev/null and b/trainer/__pycache__/magic_aggregator.cpython-310.pyc differ diff --git a/trainer/__pycache__/magic_aggregator.cpython-311.pyc b/trainer/__pycache__/magic_aggregator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4c63a27702b466871bad2b0f18230d3620cd143 Binary files /dev/null and b/trainer/__pycache__/magic_aggregator.cpython-311.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-310.pyc b/trainer/__pycache__/magic_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa1877df83f3deb6862f39988826b001e5273f8 Binary files /dev/null and b/trainer/__pycache__/magic_trainer.cpython-310.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-311.pyc b/trainer/__pycache__/magic_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42939be3e8c7c9ffc0d3985c7b46efe4ab53d0f Binary files /dev/null and b/trainer/__pycache__/magic_trainer.cpython-311.pyc differ diff --git a/trainer/__pycache__/single_trainer.cpython-311.pyc b/trainer/__pycache__/single_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d343331aa188a304c811aafdd0d1814cc17efe9 Binary files /dev/null and 
b/trainer/__pycache__/single_trainer.cpython-311.pyc differ diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py index a9d7f70c78d64347b655494de696050e82d0fa6f..a5d09149bf53e9180f58802853aa4e6eca33c438 100644 --- a/trainer/magic_aggregator.py +++ b/trainer/magic_aggregator.py @@ -6,11 +6,15 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from utils.config import build_args from fedml.core import ServerAggregator +from model.train_graph import batch_level_train +from model.train_entity import entity_level_train from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn + + +from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling -# Trainer for MoleculeNet. The evaluation metric is ROC-AUC -from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset -from utils.loaddata import load_batch_level_dataset +from utils.config import build_args +from utils.dataloader import load_data, load_entity_level_dataset, load_metadata class MagicWgetAggregator(ServerAggregator): def __init__(self, model, args, name): @@ -62,25 +66,30 @@ class MagicWgetAggregator(ServerAggregator): logging.info("Models match perfectly! :)") def _test(self, test_data, device, args): - args = build_args() - if (self.name == 'wget' or self.name == 'streamspot'): - logging.info("----------test--------") - + main_args = build_args() + dataset_name = self.name + nsnapshot = args.snapshot + if (self.name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + logging.info("----------test--------") + set_random_seed(0) + dataset = load_data(dataset_name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat model = self.model model.eval() model.to(device) - pooler = Pooling(args["pooling"]) - dataset = load_batch_level_dataset(self.name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) + model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + model = model.to(device) + pooler = Pooling(main_args.pooling) + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, + main_args.e_dim) else: - torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name)) metadata = load_metadata(self.name) - args["n_dim"] = metadata['node_feature_dim'] - args["e_dim"] = metadata['edge_feature_dim'] + main_args.n_dim = metadata['node_feature_dim'] + main_args.e_dim = metadata['edge_feature_dim'] model = self.model.to(device) model.eval() malicious, _ = metadata['malicious'] @@ -90,17 +99,17 @@ class MagicWgetAggregator(ServerAggregator): with torch.no_grad(): x_train = [] for i in range(n_train): - g = load_entity_level_dataset(self.name, 'train', i).to(device) + g = 
load_entity_level_dataset(dataset_name, 'train', i, nsnapshot, device) x_train.append(model.embed(g).cpu().detach().numpy()) del g x_train = np.concatenate(x_train, axis=0) skip_benign = 0 x_test = [] for i in range(n_test): - g = load_entity_level_dataset(self.name, 'test', i).to(device) + g = load_entity_level_dataset(self.name, 'test', i, nsnapshot, device) # Exclude training samples from the test set if i != n_test - 1: - skip_benign += g.number_of_nodes() + skip_benign += g[0].number_of_nodes() x_test.append(model.embed(g).cpu().detach().numpy()) del g x_test = np.concatenate(x_test, axis=0) diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py index 99c9a5b49e31a631f30053db0c3cf0623569b2eb..02c428c2459ec140dee5be82bbab64fcf0532c33 100644 --- a/trainer/magic_trainer.py +++ b/trainer/magic_trainer.py @@ -10,18 +10,18 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from fedml.core import ClientTrainer -from model.autoencoder import build_model from torch.utils.data.sampler import SubsetRandomSampler from dgl.dataloading import GraphDataLoader -from model.train import batch_level_train + +from model.train_graph import batch_level_train +from model.train_entity import entity_level_train +from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn + + from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling from utils.config import build_args -from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata -from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn -from model.autoencoder import build_model -from data.data_loader import transform_data -from utils.loaddata import load_batch_level_dataset +from utils.dataloader import load_data, load_entity_level_dataset, load_metadata # Trainer for MoleculeNet. 
The evaluation metric is ROC-AUC def extract_dataloaders(entries, batch_size): @@ -47,77 +47,47 @@ class MagicTrainer(ClientTrainer): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): - test_data = None - args = build_args() - if (self.name == "wget"): - args["num_hidden"] = 256 - args["max_epoch"] = 2 - args["num_layers"] = 4 - batch_size = 1 - elif (self.name == "streamspot"): - args["num_hidden"] = 256 - args["max_epoch"] = 5 - args["num_layers"] = 4 - batch_size = 12 - else: - args["num_hidden"] = 64 - args["max_epoch"] = 50 - args["num_layers"] = 3 - - max_test_score = 0 - best_model_params = {} - if (self.name == 'wget' or self.name == 'streamspot'): + main_args = build_args() + dataset_name = self.name + input('start') + if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + if (dataset_name == 'wget' or dataset_name == 'streamspot' or dataset_name == 'Unicorn-Cadets' or dataset_name == 'clearscope-e3'): + batch_size = 1 + main_args.max_epoch = 6 + + elif (dataset_name == 'SC2' or dataset_name == 'wget-long'): + batch_size = 1 + main_args.max_epoch = 1 - dataset = load_batch_level_dataset(self.name) - data = transform_data(train_data) + + dataset = load_data(dataset_name) n_node_feat = dataset['n_feat'] n_edge_feat = dataset['e_feat'] - graphs = data['dataset'] - train_index = data['train_index'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - #self.model = build_model(args) - self.model = self.model.to(device) - optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) - self.model = batch_level_train(self.model, graphs, (extract_dataloaders(train_index, batch_size)), - optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat) - test_score, _ = self.test(test_data, device, args) + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + validation_index = dataset['validation_index'] + label = dataset["labels"] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + main_args.optimizer = "adamw" + set_random_seed(0) + #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) + train_loader = extract_dataloaders(train_data[0], batch_size) + self.model = batch_level_train(self.model, train_loader, optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim, dataset_name, validation= False) else: - - metadata = load_metadata(self.name) - args["n_dim"] = metadata['node_feature_dim'] - args["e_dim"] = metadata['edge_feature_dim'] - self.model = self.model.to(device) - self.model.train() - optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) - epoch_iter = tqdm(range(args["max_epoch"])) - n_train = len(train_data[0]) - input("start?") - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in range(n_train): - g = train_data[0][i] - self.model.train() - loss = self.model(g) - loss /= n_train - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward(retain_graph=True) - 
optimizer.step() - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - if (self.name == 'wget' or self.name == 'streamspot'): - test_score, _ = self.test(test_data, device, args) - if test_score > self.max: - self.max = test_score - best_model_params = { - k: v.cpu() for k, v in self.model.state_dict().items() - } - else: - self.max = 0 - best_model_params = { + main_args.max_epoch = 50 + nsnapshot = args.snapshot + main_args.optimizer = "adam" + set_random_seed(0) + #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) + #train_loader = extract_dataloaders(train_index, batch_size) + self.model = entity_level_train(self.model, nsnapshot, optimizer, main_args.max_epoch, device, dataset_name, train_data[0]) + + self.max = 0 + best_model_params = { k: v.cpu() for k, v in self.model.state_dict().items() - } + } @@ -126,17 +96,31 @@ class MagicTrainer(ClientTrainer): def test(self, test_data, device, args): if (self.name == 'wget' or self.name == 'streamspot'): logging.info("----------test--------") - args = build_args() + main_args = build_args() + dataset_name = self.name + if dataset_name in ['streamspot', 'wget', 'SC2']: + main_args.num_hidden = 256 + main_args.num_layers = 4 + main_args.max_epoch = 50 + else: + main_args.num_hidden = 64 + main_args.num_layers = 3 + set_random_seed(0) + dataset = load_data(dataset_name, 1, 0.6, 0.2) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat model = self.model model.eval() model.to(device) - pooler = Pooling(args["pooling"]) - dataset = load_batch_level_dataset(self.name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - args["n_dim"] = n_node_feat - args["e_dim"] = n_edge_feat - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) + model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) + model = model.to(device) + pooler = Pooling(main_args.pooling) + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, + main_args.e_dim) + else: metadata = load_metadata(self.name) args["n_dim"] = metadata['node_feature_dim'] diff --git a/trainer/utils/.gitkeep b/trainer/utils/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/trainer/utils/config.py b/trainer/utils/config.py deleted file mode 100644 index bb9c1ef28d385b80c7673027b4f69b31c36b1175..0000000000000000000000000000000000000000 --- a/trainer/utils/config.py +++ /dev/null @@ -1,20 +0,0 @@ -import argparse - - -def build_args(): - parser = argparse.ArgumentParser(description="MAGIC") - parser.add_argument("--dataset", type=str, default="wget") - parser.add_argument("--device", type=int, default=-1) - parser.add_argument("--lr", type=float, default=0.001, - 
help="learning rate") - parser.add_argument("--weight_decay", type=float, default=5e-4, - help="weight decay") - parser.add_argument("--negative_slope", type=float, default=0.2, - help="the negative slope of leaky relu for GAT") - parser.add_argument("--mask_rate", type=float, default=0.5) - parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") - parser.add_argument("--optimizer", type=str, default="adam") - parser.add_argument("--loss_fn", type=str, default='sce') - parser.add_argument("--pooling", type=str, default="mean") - args = parser.parse_args() - return args diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py deleted file mode 100644 index 41e7dfc03ee39adcc1f5729d59aa21124d981fff..0000000000000000000000000000000000000000 --- a/trainer/utils/loaddata.py +++ /dev/null @@ -1,197 +0,0 @@ -import pickle as pkl -import time -import torch.nn.functional as F -import dgl -import networkx as nx -import json -from tqdm import tqdm -import os - - -class StreamspotDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(StreamspotDataset, self).__init__(name=name) - if name == 'streamspot': - path = './data/streamspot' - num_graphs = 600 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 300 <= idx <= 399: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -class WgetDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(WgetDataset, self).__init__(name=name) - if name == 'wget': - path = './data/wget/final' - num_graphs = 150 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 0 <= idx <= 24: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -def load_rawdata(name): - if name == 'streamspot': - path = './data/streamspot' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = StreamspotDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - elif name == 'wget': - path = './data/wget' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = WgetDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - else: - raise NotImplementedError - return raw_data - - -def load_batch_level_dataset(dataset_name): - dataset = load_rawdata(dataset_name) - graph, _ = dataset[0] - node_feature_dim = 0 - for g, _ in dataset: - node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) - edge_feature_dim = 0 - for g, _ in dataset: - edge_feature_dim = max(edge_feature_dim, 
g.edata["type"].max().item()) - node_feature_dim += 1 - edge_feature_dim += 1 - full_dataset = [i for i in range(len(dataset))] - train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] - print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) - - return {'dataset': dataset, - 'train_index': train_dataset, - 'full_index': full_dataset, - 'n_feat': node_feature_dim, - 'e_feat': edge_feature_dim} - - -def transform_graph(g, node_feature_dim, edge_feature_dim): - new_g = g.clone() - new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() - new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() - return new_g - - -def preload_entity_level_dataset(path): - path = './data/' + path - if os.path.exists(path + '/metadata.json'): - pass - else: - print('transforming') - train_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] - print('transforming') - test_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - node_feature_dim = 0 - for g in train_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - for g in test_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - node_feature_dim += 1 - edge_feature_dim = 0 - for g in train_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - for g in test_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - edge_feature_dim += 1 - result_test_gs = [] - for g in test_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_test_gs.append(g) - result_train_gs = [] - for g in train_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_train_gs.append(g) - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': len(result_train_gs), - 'n_test': len(result_test_gs) - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - for i, g in enumerate(result_train_gs): - with open(path + '/train{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - for i, g in enumerate(result_test_gs): - with open(path + '/test{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - - -def load_metadata(path): - preload_entity_level_dataset(path) - with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(path, t, n): - preload_entity_level_dataset(path) - with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: - data = pkl.load(f) - return data diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py deleted file mode 100644 index 4b9dc4aee1eac42f578952596a5e400adbd1e391..0000000000000000000000000000000000000000 --- a/trainer/utils/poolers.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch.nn as nn - - -class Pooling(nn.Module): - def __init__(self, pooler): - super(Pooling, self).__init__() - self.pooler = pooler - - def forward(self, graph, feat, t=None): - feat = feat - # Implement node type-specific pooling - with graph.local_scope(): - if t is None: - if self.pooler == 'mean': - return feat.mean(0, keepdim=True) - elif 
self.pooler == 'sum': - return feat.sum(0, keepdim=True) - elif self.pooler == 'max': - return feat.max(0, keepdim=True) - else: - raise NotImplementedError - elif isinstance(t, int): - mask = (graph.ndata['type'] == t) - if self.pooler == 'mean': - return feat[mask].mean(0, keepdim=True) - elif self.pooler == 'sum': - return feat[mask].sum(0, keepdim=True) - elif self.pooler == 'max': - return feat[mask].max(0, keepdim=True) - else: - raise NotImplementedError - else: - mask = (graph.ndata['type'] == t[0]) - for i in range(1, len(t)): - mask |= (graph.ndata['type'] == t[i]) - if self.pooler == 'mean': - return feat[mask].mean(0, keepdim=True) - elif self.pooler == 'sum': - return feat[mask].sum(0, keepdim=True) - elif self.pooler == 'max': - return feat[mask].max(0, keepdim=True) - else: - raise NotImplementedError diff --git a/trainer/utils/streamspot_parser.py b/trainer/utils/streamspot_parser.py deleted file mode 100644 index 04438eb0f6336ae2bad2015ad6caf4919f48f9a3..0000000000000000000000000000000000000000 --- a/trainer/utils/streamspot_parser.py +++ /dev/null @@ -1,53 +0,0 @@ -import networkx as nx -from tqdm import tqdm -import json -raw_path = '../data/streamspot/' - -NUM_GRAPHS = 600 -node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] -edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', - 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] -node_type_set = set(node_type_dict) -edge_type_set = set(edge_type_dict) - -count_graph = 0 -with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: - lines = f.readlines() - g = nx.DiGraph() - node_map = {} - count_node = 0 - for line in tqdm(lines): - src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') - graph_id = int(graph_id) - if src_type not in node_type_set or dst_type not in node_type_set: - continue - if etype not in edge_type_set: - continue - if graph_id != count_graph: - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) - assert graph_id == count_graph - g = nx.DiGraph() - count_node = 0 - if src not in node_map: - node_map[src] = count_node - g.add_node(count_node, type=src_type) - count_node += 1 - if dst not in node_map: - node_map[dst] = count_node - g.add_node(count_node, type=dst_type) - count_node += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=etype) - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py deleted file mode 100644 index 76a91d424a3eea0bbec87ffb635a0d98b27d72f7..0000000000000000000000000000000000000000 --- a/trainer/utils/trace_parser.py +++ /dev/null @@ -1,288 +0,0 @@ -import argparse -import json -import os -import random -import re - -from tqdm import tqdm -import networkx as nx -import pickle as pkl - - -node_type_dict = {} -edge_type_dict = {} -node_type_cnt = 0 -edge_type_cnt = 0 - -metadata = { - 'trace':{ - 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 
'ta1-trace-e3-official-1.json.3'], - 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] - }, - 'theia':{ - 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], - 'test': ['ta1-theia-e3-official-6r.json.8'] - }, - 'cadets':{ - 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], - 'test': ['ta1-cadets-e3-official-2.json'] - } -} - - -pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') -pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_type = re.compile(r'type\":\"(.*?)\"') -pattern_time = re.compile(r'timestampNanos\":(.*?),') -pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') -pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') -pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') - -adversarial = 'None' - - -def read_single_graph(dataset, malicious, path, test=False): - global node_type_cnt, edge_type_cnt - g = nx.DiGraph() - print('converting {} ...'.format(path)) - path = '../data/{}/'.format(dataset) + path + '.txt' - f = open(path, 'r') - lines = [] - for l in f.readlines(): - split_line = l.split('\t') - src, src_type, dst, dst_type, edge_type, ts = split_line - ts = int(ts) - if not test: - if src in malicious or dst in malicious: - if src in malicious and src_type != 'MemoryObject': - continue - if dst in malicious and dst_type != 'MemoryObject': - continue - - if src_type not in node_type_dict: - node_type_dict[src_type] = node_type_cnt - node_type_cnt += 1 - if dst_type not in node_type_dict: - node_type_dict[dst_type] = node_type_cnt - node_type_cnt += 1 - if edge_type not in edge_type_dict: - edge_type_dict[edge_type] = edge_type_cnt - edge_type_cnt += 1 - if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: - lines.append([dst, src, dst_type, src_type, edge_type, ts]) - else: - lines.append([src, dst, src_type, dst_type, edge_type, ts]) - lines.sort(key=lambda l: l[5]) - - node_map = {} - node_type_map = {} - node_cnt = 0 - node_list = [] - for l in lines: - src, dst, src_type, dst_type, edge_type = l[:5] - src_type_id = node_type_dict[src_type] - dst_type_id = node_type_dict[dst_type] - edge_type_id = edge_type_dict[edge_type] - if src not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if src in malicious: - src_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if src not in malicious: - i = random.randint(1, 20) - if i == 1: - src_type_id = node_type_dict['NetFlowObject'] - node_map[src] = node_cnt - g.add_node(node_cnt, type=src_type_id) - node_list.append(src) - node_type_map[src] = src_type - node_cnt += 1 - if dst not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if dst in malicious: - dst_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if dst not in malicious: - i = random.randint(1, 20) - if i == 1: - dst_type_id = node_type_dict['NetFlowObject'] - node_map[dst] = node_cnt - g.add_node(node_cnt, type=dst_type_id) - 
node_type_map[dst] = dst_type - node_list.append(dst) - node_cnt += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=edge_type_id) - if (adversarial == 'MSE' or adversarial == 'MCE') and test: - for i, node in enumerate(node_list): - if node in malicious: - while True: - another_node = random.choice(node_list) - if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) - print(node, another_node) - break - if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) - print(node, another_node) - break - if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) - print(node, another_node) - break - if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': - break - return node_map, g - - -def preprocess_dataset(dataset): - id_nodetype_map = {} - id_nodename_map = {} - for file in os.listdir('../data/{}/'.format(dataset)): - if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: - print('reading {} ...'.format(file)) - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue - if len(pattern_uuid.findall(line)) == 0: print(line) - uuid = pattern_uuid.findall(line)[0] - subject_type = pattern_type.findall(line) - - if len(subject_type) < 1: - if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: - subject_type = 'MemoryObject' - if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: - subject_type = 'NetFlowObject' - if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: - subject_type = 'UnnamedPipeObject' - else: - subject_type = subject_type[0] - - if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: - continue - id_nodetype_map[uuid] = subject_type - if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_file_name.findall(line)[0] - elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_process_name.findall(line)[0] - elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] - for key in metadata[dataset]: - for file in metadata[dataset][key]: - if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): - continue - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') - print('processing {} ...'.format(file)) - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line: - edgeType = pattern_type.findall(line)[0] - timestamp = pattern_time.findall(line)[0] - srcId = 
pattern_src.findall(line) - - if len(srcId) == 0: continue - srcId = srcId[0] - if not srcId in id_nodetype_map: - continue - srcType = id_nodetype_map[srcId] - dstId1 = pattern_dst1.findall(line) - if len(dstId1) > 0 and dstId1[0] != 'null': - dstId1 = dstId1[0] - if not dstId1 in id_nodetype_map: - continue - dstType1 = id_nodetype_map[dstId1] - this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( - dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge1) - - dstId2 = pattern_dst2.findall(line) - if len(dstId2) > 0 and dstId2[0] != 'null': - dstId2 = dstId2[0] - if not dstId2 in id_nodetype_map.keys(): - continue - dstType2 = id_nodetype_map[dstId2] - this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( - dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge2) - fw.close() - f.close() - if len(id_nodename_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') - json.dump(id_nodename_map, fw) - if len(id_nodetype_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') - json.dump(id_nodetype_map, fw) - - -def read_graphs(dataset): - malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) - f = open(malicious_entities, 'r') - malicious_entities = set() - for l in f.readlines(): - malicious_entities.add(l.lstrip().rstrip()) - - preprocess_dataset(dataset) - train_gs = [] - for file in metadata[dataset]['train']: - _, train_g = read_single_graph(dataset, malicious_entities, file, False) - train_gs.append(train_g) - test_gs = [] - test_node_map = {} - count_node = 0 - for file in metadata[dataset]['test']: - node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) - assert len(node_map) == test_g.number_of_nodes() - test_gs.append(test_g) - for key in node_map: - if key not in test_node_map: - test_node_map[key] = node_map[key] + count_node - count_node += test_g.number_of_nodes() - - if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): - with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodename_map = json.load(f) - with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodetype_map = json.load(f) - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': - final_malicious_entities.append(test_node_map[e]) - if e in id_nodename_map: - malicious_names.append(id_nodename_map[e]) - f.write('{}\t{}\n'.format(e, id_nodename_map[e])) - else: - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - else: - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map: - final_malicious_entities.append(test_node_map[e]) - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - - pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], 
open('../data/{}/test.pkl'.format(dataset), 'wb')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CDM Parser') - parser.add_argument("--dataset", type=str, default="trace") - args = parser.parse_args() - if args.dataset not in ['trace', 'theia', 'cadets']: - raise NotImplementedError - read_graphs(args.dataset) - diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py deleted file mode 100644 index bcf9a481407696d73f30fe1dde279154a05702b3..0000000000000000000000000000000000000000 --- a/trainer/utils/utils.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial -import numpy as np -import random -import torch.optim as optim - - -def create_optimizer(opt, model, lr, weight_decay): - opt_lower = opt.lower() - parameters = model.parameters() - opt_args = dict(lr=lr, weight_decay=weight_decay) - optimizer = None - opt_split = opt_lower.split("_") - opt_lower = opt_split[-1] - if opt_lower == "adam": - optimizer = optim.Adam(parameters, **opt_args) - elif opt_lower == "adamw": - optimizer = optim.AdamW(parameters, **opt_args) - elif opt_lower == "adadelta": - optimizer = optim.Adadelta(parameters, **opt_args) - elif opt_lower == "radam": - optimizer = optim.RAdam(parameters, **opt_args) - elif opt_lower == "sgd": - opt_args["momentum"] = 0.9 - return optim.SGD(parameters, **opt_args) - else: - assert False and "Invalid optimizer" - return optimizer - - -def random_shuffle(x, y): - idx = list(range(len(x))) - random.shuffle(idx) - return x[idx], y[idx] - - -def set_random_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.determinstic = True - - -def create_activation(name): - if name == "relu": - return nn.ReLU() - elif name == "gelu": - return nn.GELU() - elif name == "prelu": - return nn.PReLU() - elif name is None: - return nn.Identity() - elif name == "elu": - return nn.ELU() - else: - raise NotImplementedError(f"{name} is not implemented.") - - -def create_norm(name): - if name == "layernorm": - return nn.LayerNorm - elif name == "batchnorm": - return nn.BatchNorm1d - elif name == "graphnorm": - return partial(NormLayer, norm_type="groupnorm") - else: - return None - - -class NormLayer(nn.Module): - def __init__(self, hidden_dim, norm_type): - super().__init__() - if norm_type == "batchnorm": - self.norm = nn.BatchNorm1d(hidden_dim) - elif norm_type == "layernorm": - self.norm = nn.LayerNorm(hidden_dim) - elif norm_type == "graphnorm": - self.norm = norm_type - self.weight = nn.Parameter(torch.ones(hidden_dim)) - self.bias = nn.Parameter(torch.zeros(hidden_dim)) - - self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) - else: - raise NotImplementedError - - def forward(self, graph, x): - tensor = x - if self.norm is not None and type(self.norm) != str: - return self.norm(tensor) - elif self.norm is None: - return tensor - - batch_list = graph.batch_num_nodes - batch_size = len(batch_list) - batch_list = torch.Tensor(batch_list).long().to(tensor.device) - batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) - batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) - mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) - mean = mean.scatter_add_(0, batch_index, tensor) - mean = (mean.T / batch_list).T - mean = mean.repeat_interleave(batch_list, dim=0) - - sub = tensor - mean * self.mean_scale - - std = 
torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) - std = std.scatter_add_(0, batch_index, sub.pow(2)) - std = ((std.T / batch_list).T + 1e-6).sqrt() - std = std.repeat_interleave(batch_list, dim=0) - return self.weight * sub / std + self.bias diff --git a/trainer/utils/wget_parser.py b/trainer/utils/wget_parser.py deleted file mode 100644 index 3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055..0000000000000000000000000000000000000000 --- a/trainer/utils/wget_parser.py +++ /dev/null @@ -1,820 +0,0 @@ -import argparse -import json -import os - -import xxhash -import tqdm -import logging -import networkx as nx -from tqdm import tqdm -import time -import datetime - - -valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] -CONSOLE_ARGUMENTS = None - - -def hashgen(l): - """Generate a single hash value from a list. @l is a list of - string values, which can be properties of a node/edge. This - function returns a single hashed integer value.""" - hasher = xxhash.xxh64() - for e in l: - hasher.update(e) - return hasher.intdigest() - - -def parse_nodes(json_string, node_map): - """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). - Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, - which is assigned by CamFlow to uniquely identify a node object, to a hashed - value (in str) which represents the 'type' of the node. """ - json_object = None - try: - # use "ignore" if non-decodeable exists in the @json_string - json_object = json.loads(json_string) - except Exception as e: - print("Exception ({}) occurred when parsing a node in JSON:".format(e)) - print(json_string) - exit(1) - if "activity" in json_object: - activity = json_object["activity"] - for uid in activity: - if not uid in node_map: # only parse unseen nodes - if "prov:type" not in activity[uid]: - # a node must have a type. - # record this issue if logging is turned on - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = activity[uid]["prov:type"] - - if "entity" in json_object: - entity = json_object["entity"] - for uid in entity: - if not uid in node_map: - if "prov:type" not in entity[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = entity[uid]["prov:type"] - - -def parse_all_nodes(filename, node_map): - """Parse all nodes in CamFlow data. @filename is the file path of - the CamFlow data to parse. @node_map contains the mappings of all - CamFlow nodes to their hashed attributes. """ - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(filename, 'r') as f: - # each line in CamFlow data could contain multiple - # provenance nodes, we call @parse_nodes routine. - for line in f: - pb.update() # for progress tracking - parse_nodes(line, node_map) - f.close() - pb.close() - - -def parse_all_edges(inputfile, outputfile, node_map, noencode): - """Parse all edges (including their timestamp) from CamFlow data file @inputfile - to an @outputfile. Before this function is called, parse_all_nodes should be called - to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, - we do not hash the nodes' original UUIDs generated by CamFlow to integers. 
This - function returns the total number of valid edge parsed from CamFlow dataset. - - The output edgelist has the following format for each line, if -s is not set: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> - If -s is set, each line would look like: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" - total_edges = 0 - smallest_timestamp = None - # scan through the entire file to find the smallest timestamp from all the edges. - # this step is only needed if we need to add some statistical information. - if CONSOLE_ARGUMENTS.stats: - description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - continue - if "cf:date" not in used[uid]: - continue - if "prov:entity" not in used[uid]: - continue - if "prov:activity" not in used[uid]: - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - continue - if "cf:date" not in wasGeneratedBy[uid]: - continue - if "prov:entity" not in wasGeneratedBy[uid]: - continue - if "prov:activity" not in wasGeneratedBy[uid]: - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - continue - if "cf:date" not in wasInformedBy[uid]: - continue - if "prov:informant" not in wasInformedBy[uid]: - continue - if "prov:informed" not in wasInformedBy[uid]: - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - continue - if "cf:date" not in wasDerivedFrom[uid]: - continue - if "prov:usedEntity" not in wasDerivedFrom[uid]: - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - continue - srcUUID = 
wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - continue - if "cf:date" not in wasAssociatedWith[uid]: - continue - if "prov:agent" not in wasAssociatedWith[uid]: - continue - if "prov:activity" not in wasAssociatedWith[uid]: - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - f.close() - pb.close() - - # we will go through the CamFlow data (again) and output edgelist to a file - output = open(outputfile, "w+") - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - # an edge must have a type; if not, - # we will have to skip the edge. Log - # this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without type: {}".format(uid)) - continue - else: - edgetype = "used" - # cf:id is used as logical timestamp to order edges - if "cf:id" not in used[uid]: - # an edge must have a logical timestamp; - # if not we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = used[uid]["cf:id"] - if "prov:entity" not in used[uid]: - # an edge's source node must exist; - # if not, we will have to skip the - # edge. Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) - continue - if "prov:activity" not in used[uid]: - # an edge's destination node must exist; - # if not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], - uid)) - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - # both source and destination node must - # exist in @node_map; if not, we will - # have to skip the edge. Log this issue - # if verbose is set. 
- if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in used[uid]: - # an edge must have a timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without timestamp: {}".format(uid)) - continue - else: - # we only record @adjusted_ts if we need - # to record stats of CamFlow dataset. - if CONSOLE_ARGUMENTS.stats: - ts_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in used[uid]: - # an edge must have a jiffies timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without jiffies: {}".format(uid)) - continue - else: - # we only record @jiffies if - # the option is set - if CONSOLE_ARGUMENTS.jiffies: - jiffies = used[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasGeneratedBy" - if "cf:id" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasGeneratedBy[uid]["cf:id"] - if "prov:entity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( - 
wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasGeneratedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasInformedBy" - if "cf:id" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasInformedBy[uid]["cf:id"] - if "prov:informant" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - if "prov:informed" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) 
record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasInformedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) - continue - else: - edgetype = "wasDerivedFrom" - if "cf:id" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasDerivedFrom[uid]["cf:id"] - if "prov:usedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without jiffies: 
{}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasDerivedFrom[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) - continue - else: - edgetype = "wasAssociatedWith" - if "cf:id" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasAssociatedWith[uid]["cf:id"] - if "prov:agent" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasAssociatedWith[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - 
output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - f.close() - output.close() - pb.close() - return total_edges - -def read_single_graph(file_name, threshold): - graph = [] - edge_cnt = 0 - with open(file_name, 'r') as f: - for line in f: - try: - edge = line.strip().split("\t") - new_edge = [edge[0], edge[1]] - attributes = edge[2].strip().split(":") - source_node_type = attributes[0] - destination_node_type = attributes[1] - edge_type = attributes[2] - edge_order = attributes[3] - - new_edge.append(source_node_type) - new_edge.append(destination_node_type) - new_edge.append(edge_type) - new_edge.append(edge_order) - graph.append(new_edge) - edge_cnt += 1 - except: - print("{}".format(line)) - f.close() - graph.sort(key=lambda e: e[5]) - if len(graph) <= threshold: - return graph - else: - return graph[:threshold] - - -def process_graph(name, threshold): - graph = read_single_graph(name, threshold) - result_graph = nx.DiGraph() - cnt = 0 - for num, edge in enumerate(graph): - cnt += 1 - src, dst, src_type, dst_type, edge_type = edge[:5] - if True:# src_type in valid_node_type and dst_type in valid_node_type: - if not result_graph.has_node(src): - result_graph.add_node(src, type=src_type) - if not result_graph.has_node(dst): - result_graph.add_node(dst, type=dst_type) - if not result_graph.has_edge(src, dst): - result_graph.add_edge(src, dst, type=edge_type) - if bidirection: - result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) - return cnt, result_graph - - -node_type_list = [] -edge_type_list = [] -node_type_dict = {} -edge_type_dict = {} - - -def format_graph(g, name): - new_g = nx.DiGraph() - node_map = {} - node_cnt = 0 - for n in g.nodes: - node_map[n] = node_cnt - new_g.add_node(node_cnt, type=g.nodes[n]['type']) - node_cnt += 1 - for e in g.edges: - new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) - for n in new_g.nodes: - node_type = new_g.nodes[n]['type'] - if not node_type in node_type_dict: - node_type_list.append(node_type) - node_type_dict[node_type] = 1 - else: - node_type_dict[node_type] += 1 - for e in new_g.edges: - edge_type = new_g.edges[e]['type'] - if not edge_type in edge_type_dict: - edge_type_list.append(edge_type) - edge_type_dict[edge_type] = 1 - else: - edge_type_dict[edge_type] += 1 - for n in new_g.nodes: - new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) - for e in new_g.edges: - new_g.edges[e]['type'] = edge_type_list.index(new_g.edges[e]['type']) - with open('{}.json'.format(name), 'w', encoding='utf-8') as f: - json.dump(nx.node_link_data(new_g), f) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist') - args = parser.parse_args() - args.stats = False - args.verbose = False - args.jiffies = False - args.input = '../data/wget/raw/' - args.output = 
'../data/wget/processed/' - args.final_output = '../data/wget/final/' - args.noencode = False - if not os.path.exists(args.input): - os.mkdir(args.input) - if not os.path.exists(args.output): - os.mkdir(args.output) - if not os.path.exists(args.final_output): - os.mkdir(args.final_output) - CONSOLE_ARGUMENTS = args - - if args.verbose: - logging.basicConfig(filename=args.log, level=logging.DEBUG) - cnt = 0 - for fname in os.listdir(args.input): - cnt += 1 - node_map = dict() - parse_all_nodes(args.input + '/{}'.format(fname), node_map) - total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map, - args.noencode) - if args.stats: - total_nodes = len(node_map) - stats = open(args.stats_file + '/{}.log'.format(cnt), "a+") - stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges)) - - bidirection = False - threshold = 10000000 # infinity - interaction_dict = [] - graph_cnt = 0 - result_graphs = [] - input = args.output - base = args.final_output - - line_cnt = 0 - for i in tqdm(range(cnt)): - single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold) - format_graph(result_graph, '{}{}'.format(base, i)) - line_cnt += single_cnt - - print(line_cnt // 150) - print(len(node_type_list)) - print(node_type_dict) - print(len(edge_type_list)) - print(edge_type_dict) - diff --git a/utils/.gitkeep b/utils/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2a48894d4871479b958f92c0ae288f5c3e6cad Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ diff --git a/utils/__pycache__/config.cpython-311.pyc b/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a2ca013c4c0826c94b817a15f5aeaadf3fbf021 Binary files /dev/null and b/utils/__pycache__/config.cpython-311.pyc differ diff --git a/utils/__pycache__/configJupyter.cpython-311.pyc b/utils/__pycache__/configJupyter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470b992f5214101b8dcf54596f455e87c14e0d9b Binary files /dev/null and b/utils/__pycache__/configJupyter.cpython-311.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-310.pyc b/utils/__pycache__/dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87ef800e3349966c54f8aff0158c342c15ce957 Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-310.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-311.pyc b/utils/__pycache__/dataloader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435191a8fbab8cf94e7ec9ff76aa01f782a87157 Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-311.pyc differ diff --git a/utils/__pycache__/loaddata.cpython-311.pyc b/utils/__pycache__/loaddata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516f605481c393aef9331a111ab596d8f327e0ee Binary files /dev/null and b/utils/__pycache__/loaddata.cpython-311.pyc differ diff --git a/utils/__pycache__/poolers.cpython-310.pyc b/utils/__pycache__/poolers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441c02bbdd690fbfccc5acf786b9ded0f067bd8d Binary files /dev/null and 
b/utils/__pycache__/poolers.cpython-310.pyc differ diff --git a/utils/__pycache__/poolers.cpython-311.pyc b/utils/__pycache__/poolers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4464f0cddb9bf9523f2ee1a43193ac41ebdb2f73 Binary files /dev/null and b/utils/__pycache__/poolers.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10658e8f670d580962f3f3982605978f526561a7 Binary files /dev/null and b/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c0ff7b920416d5844cb0abf8995663dd0bf741 Binary files /dev/null and b/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/utils/config.py b/utils/config.py index ee14920f4b2e36dfcd0d7fddca100572837ade20..75b33c7e63bcbcab20114e07ab7075922c59417b 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,13 +1,23 @@ +import argparse + + def build_args(): - args = {} - args["dataset"] ="wget" - args["device"]=-1 - args["lr"]=0.001 - args["weight_decay"]=5e-4 - args["negative_slope"]=0.2 - args["mask_rate"]=0.5 - args["alpha_l"]=3 - args["optimizer"]="adam" - args["loss_fn"]='sce' - args["pooling"] = "mean" - return args + parser = argparse.ArgumentParser(description="MAGIC") + parser.add_argument("--device", type=int, default=-1) + parser.add_argument("--lr", type=float, default=0.0001, + help="learning rate") + parser.add_argument("--weight_decay", type=float, default=5e-4, + help="weight decay") + parser.add_argument("--negative_slope", type=float, default=0.2, + help="the negative slope of leaky relu for GAT") + parser.add_argument("--mask_rate", type=float, default=0.5) + parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") + parser.add_argument("--optimizer", type=str, default="adam") + parser.add_argument("--loss_fn", type=str, default='sce') + parser.add_argument("--pooling", type=str, default="mean") + parser.add_argument("--run_id", type=str, default="fedgraphnn_cs_ch") + parser.add_argument("--cf", type=str, default="fedml_config.yaml") + parser.add_argument("--rank", type=int, default=0) + parser.add_argument("--role", type=str, default="server") + args = parser.parse_args() + return args \ No newline at end of file diff --git a/utils/configJupyter.py b/utils/configJupyter.py new file mode 100644 index 0000000000000000000000000000000000000000..269779745a9c835b757ea725d201e11e5ee3e87b --- /dev/null +++ b/utils/configJupyter.py @@ -0,0 +1,14 @@ + +def build_args(): + + args = {} + args['device'] = -1 + args['lr'] = 0.01 + args['weight_decay'] = 5e-4 + args['negative_slope'] = 0.2 + args['mask_rate'] = 0.5 + args['alpha_l'] = 3 + args['optimizer'] = 'adam' + args['loss_fn'] = "sce" + args['pooling'] = "mean" + return args \ No newline at end of file diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..992c6fcacaee228330db91e920c03c6d4f2fc8ad --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,331 @@ +import os +import random +import networkx as nx +import dgl +import torch +import pickle as pkl +import json +import logging +import numpy as np + +path_dataset = 'D:/PFE DATASETS/' + +def darpa_split(name): + metadata = load_metadata(name) + n_train = metadata['n_train'] + train_dataset = range(n_train) + 
train_labels = [0]* n_train + + + return ( + train_dataset, + train_labels, + [], + [], + [], + [] + ) + + +def create_random_split(name, snapshots): + dataset = load_data(name) + # Random 80/10/10 split as suggested + + + all_idxs = list(range(len(dataset))) + random.shuffle(all_idxs) + + train_dataset = dataset['train_index'] + train_labels = [] + for id in train_dataset: + train_labels.append(dataset['labels'][id]) + + val_dataset = dataset['validation_index'] + val_labels = [] + for id in val_dataset: + val_labels.append(dataset['labels'][id]) + + test_dataset = dataset['test_index'] + test_labels = [] + for id in test_dataset: + test_labels.append(dataset['labels'][id]) + + + return ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) + + + +def partition_data_by_sample_size( + client_number, name, snapshots +): + if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = create_random_split(name, snapshots) + else: + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = darpa_split(name) + + num_train_samples = len(train_dataset) + num_val_samples = len(val_dataset) + num_test_samples = len(test_dataset) + + train_idxs = list(range(num_train_samples)) + val_idxs = list(range(num_val_samples)) + test_idxs = list(range(num_test_samples)) + + random.shuffle(train_idxs) + random.shuffle(val_idxs) + random.shuffle(test_idxs) + + partition_dicts = [None] * client_number + + + clients_idxs_train = np.array_split(train_idxs, client_number) + clients_idxs_val = np.array_split(val_idxs, client_number) + clients_idxs_test = np.array_split(test_idxs, client_number) + + labels_of_all_clients = [] + for client in range(client_number): + client_train_idxs = clients_idxs_train[client] + client_val_idxs = clients_idxs_val[client] + client_test_idxs = clients_idxs_test[client] + + train_dataset_client = [ + train_dataset[idx] for idx in client_train_idxs + ] + train_labels_client = [train_labels[idx] for idx in client_train_idxs] + labels_of_all_clients.append(train_labels_client) + + val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] + val_labels_client = [val_labels[idx] for idx in client_val_idxs] + + test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] + test_labels_client = [test_labels[idx] for idx in client_test_idxs] + + + partition_dict = { + "train": train_dataset_client, + "val": val_dataset_client, + "test": test_dataset_client, + } + + partition_dicts[client] = partition_dict + global_data_dict = { + "train": train_dataset, + "val": val_dataset, + "test": test_dataset, + } + + return global_data_dict, partition_dicts + +def load_partition_data( + client_number, + name, + snapshots, + global_test=True, +): + global_data_dict, partition_dicts = partition_data_by_sample_size( + client_number, name, snapshots + ) + + data_local_num_dict = dict() + train_data_local_dict = dict() + val_data_local_dict = dict() + test_data_local_dict = dict() + + + + # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. 
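# (Editor's annotation.) The "MOLECULE" wording above is most likely carried
# over from FedML's MoleculeNet example; in this project each sample is an
# entire provenance graph, so the same batch-size-of-1 reasoning applies.
# A hypothetical call, using the parameter names from the signature above:
#   (train_num, val_num, test_num, train_g, val_g, test_g,
#    local_num, local_train, local_val, local_test) = load_partition_data(
#        client_number=3, name="wget", snapshots=1, global_test=True)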
+ train_data_global = global_data_dict["train"] + val_data_global = global_data_dict["val"] + test_data_global = global_data_dict["test"] + train_data_num = len(global_data_dict["train"]) + val_data_num = len(global_data_dict["val"]) + test_data_num = len(global_data_dict["test"]) + + for client in range(client_number): + train_dataset_client = partition_dicts[client]["train"] + val_dataset_client = partition_dicts[client]["val"] + test_dataset_client = partition_dicts[client]["test"] + + data_local_num_dict[client] = len(train_dataset_client) + train_data_local_dict[client] = train_dataset_client, + + val_data_local_dict[client] = val_dataset_client + + test_data_local_dict[client] = ( + test_data_global + if global_test + else test_dataset_client + + ) + + logging.info( + "Client idx = {}, local sample number = {}".format( + client, len(train_dataset_client) + ) + ) + + return ( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) + + + + + + + +def preload_entity_level_dataset(name): + path = path_dataset + name + if os.path.exists(path + '/metadata.json'): + pass + else: + + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + n_train = len(os.listdir(path + '/train')) + n_test = len(os.listdir(path + '/test')) + + g = pkl.load(open(path + '/train/graph0/graph0.pkl', 'rb')) + + node_feature_dim = len(g.ndata['attr'][0]) + edge_feature_dim = len(g.edata['attr'][0]) + + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': n_train, + 'n_test': n_test + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + + + +def load_metadata(name): + preload_entity_level_dataset(name) + with open( path_dataset + name + '/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(name, t, n, snapshot, device): + preload_entity_level_dataset(name) + graphs = [] + for i in range(snapshot): + with open(path_dataset + name + '/' + t + '/graph{}/graph{}.pkl'.format(n, str(i)), 'rb') as f: + graphs.append(pkl.load(f).to(device)) + return graphs + + +def get_labels(name): + if (name=="wget" ): + return [1] * 25 + [0] * 125 + elif (name=="streamspot"): + return [0] * 300 + [1] * 100 + [0] * 200 + elif (name == 'SC2'): + return [0] * 125 + [1] * 25 + elif (name == 'Unicorn-Cadets'): + return [0] * 109 + [1] * 3 + elif (name == 'wget-long'): + return [0] * 100 + [1] * 5 + elif (name == 'clearscope-e3'): + return [0] * 44 + [1] * 50 + +def load_data(name): + if name == "wget": + n, n_dim, e_dim = 150, 14, 4 + full_dataset_index = list(range(n)) + train_dataset = list(range(50, 150)) + validation_dataset = list(range(50)) + test_dataset = list(range(50)) + elif name == "streamspot": + n, n_dim, e_dim = 600, 8, 26 + full_dataset_index = list(range(n)) + train_dataset = list(range(300)) + validation_dataset = list(range(300, 350)) + list(range(500,550)) + test_dataset = list(range(300, 400)) + list(range(400,500))+ list(range(500,600)) + elif name == 'SC2': + n_dim = len(pkl.load(open(path_dataset + 'SC2/node.pkl', 'rb')).keys()) + e_dim = len(pkl.load(open(path_dataset + 'SC2/edge.pkl', 'rb')).keys()) + n, full_dataset_index = 150, list(range(150)) + train_dataset = list(range(100)) + validation_dataset = list(range(100, 150)) + test_dataset = list(range(100, 150)) + elif 
name in ['Unicorn-Cadets', 'wget-long', 'clearscope-e3']: + n_dim = len(pkl.load(open(path_dataset + '{}/node.pkl'.format(name), 'rb')).keys()) + e_dim = len(pkl.load(open(path_dataset + '{}/edge.pkl'.format(name), 'rb')).keys()) + if name == 'Unicorn-Cadets': + n, train_dataset = 112, list(range(70)) + elif name == 'wget-long': + n, train_dataset = 105, list(range(70)) + else: + n, train_dataset = 94, list(range(30)) + full_dataset_index = list(range(n)) + validation_dataset = list(range(train_dataset[-1], n)) + test_dataset = validation_dataset + return {'dataset': full_dataset_index, + 'train_index': train_dataset, + 'test_index': test_dataset, + 'validation_index': validation_dataset, + 'full_index': full_dataset_index, + 'n_feat': n_dim, + 'e_feat': e_dim, + 'labels': get_labels(name)} + + + +def load_graph(id, name ,device): + graphs = [] + + if (name == "wget"): + path = path_dataset + 'wget/cache/' + 'graph{}'.format(str(id)) + elif (name == "streamspot"): + path = path_dataset + 'streamspot/cache/' + 'graph{}'.format(str(id)) + elif (name == "SC2"): + if (id < 125): path = path_dataset + 'SC2/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'SC2/cache/attack/' + 'graph{}'.format(str(id - 125)) + elif (name == 'Unicorn-Cadets'): + if (id < 109): path = path_dataset + 'Unicorn-Cadets/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'Unicorn-Cadets/cache/attack/' + 'graph{}'.format(str(id - 109)) + elif (name == 'wget-long'): + if (id < 100): path = path_dataset + 'wget-long/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'wget-long/cache/attack/' + 'graph{}'.format(str(id - 100)) + elif (name == 'clearscope-e3'): + if (id < 44): path = path_dataset + 'clearscope-e3/cache/benign/' + 'graph{}'.format(str(id)) + else: path = path_dataset + 'clearscope-e3/cache/attack/' + 'graph{}'.format(str(id - 44)) + + for fname in os.listdir(path): + graphs.append(pkl.load(open(path + '/' + fname, 'rb')).to(device)) + + return graphs \ No newline at end of file diff --git a/utils/loaddata.py b/utils/loaddata.py deleted file mode 100644 index ca48e0fd60aa69712f48c06f8083b37f679cf797..0000000000000000000000000000000000000000 --- a/utils/loaddata.py +++ /dev/null @@ -1,207 +0,0 @@ -import pickle as pkl -import time -import torch.nn.functional as F -import dgl -import networkx as nx -import json -from tqdm import tqdm -import os - - - -class StreamspotDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(StreamspotDataset, self).__init__(name=name) - if name == 'streamspot': - path = './data/streamspot' - num_graphs = 600 - self.graphs = [] - self.labels = [] - print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - if 300 <= idx <= 399: - self.labels.append(1) - else: - self.labels.append(0) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -class WgetDataset(dgl.data.DGLDataset): - def process(self): - pass - - def __init__(self, name): - super(WgetDataset, self).__init__(name=name) - if name == 'wget': - pathattack = '/data/wget/finalattack' - pathbenin = 'data/wget/finalbenin' - num_graphs_benin = 125 - num_graphs_attack = 25 - self.graphs = [] - self.labels = [] - 
print('Loading {} dataset...'.format(name)) - for i in tqdm(range(num_graphs_benin)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - self.labels.append(0) - - for i in tqdm(range(num_graphs_attack)): - idx = i - g = dgl.from_networkx( - nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), - node_attrs=['type'], - edge_attrs=['type'] - ) - self.graphs.append(g) - self.labels.append(1) - else: - raise NotImplementedError - - def __getitem__(self, i): - return self.graphs[i], self.labels[i] - - def __len__(self): - return len(self.graphs) - - -def load_rawdata(name): - if name == 'streamspot': - path = './data/streamspot' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = StreamspotDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - elif name == 'wget': - path = './data/wget' - if os.path.exists(path + '/graphs.pkl'): - print('Loading processed {} dataset...'.format(name)) - raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) - else: - raw_data = WgetDataset(name) - pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) - else: - raise NotImplementedError - return raw_data - - -def load_batch_level_dataset(dataset_name): - dataset = load_rawdata(dataset_name) - graph, _ = dataset[0] - node_feature_dim = 0 - for g, _ in dataset: - node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) - edge_feature_dim = 0 - for g, _ in dataset: - edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) - node_feature_dim += 1 - edge_feature_dim += 1 - full_dataset = [i for i in range(len(dataset))] - train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] - print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) - - return {'dataset': dataset, - 'train_index': train_dataset, - 'full_index': full_dataset, - 'n_feat': node_feature_dim, - 'e_feat': edge_feature_dim} - - -def transform_graph(g, node_feature_dim, edge_feature_dim): - new_g = g.clone() - new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() - new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() - return new_g - - -def preload_entity_level_dataset(path): - path = './data/' + path - if os.path.exists(path + '/metadata.json'): - pass - else: - print('transforming') - train_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] - print('transforming') - test_gs = [dgl.from_networkx( - nx.node_link_graph(g), - node_attrs=['type'], - edge_attrs=['type'] - ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - node_feature_dim = 0 - for g in train_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - for g in test_gs: - node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) - node_feature_dim += 1 - edge_feature_dim = 0 - for g in train_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - for g in test_gs: - edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) - edge_feature_dim += 1 - result_test_gs = [] - for g in 
test_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_test_gs.append(g) - result_train_gs = [] - for g in train_gs: - g = transform_graph(g, node_feature_dim, edge_feature_dim) - result_train_gs.append(g) - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': len(result_train_gs), - 'n_test': len(result_test_gs) - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - for i, g in enumerate(result_train_gs): - with open(path + '/train{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - for i, g in enumerate(result_test_gs): - with open(path + '/test{}.pkl'.format(i), 'wb') as f: - pkl.dump(g, f) - - -def load_metadata(path): - preload_entity_level_dataset(path) - with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(path, t, n): - preload_entity_level_dataset(path) - with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: - data = pkl.load(f) - return data diff --git a/utils/streamspot_parser.py b/utils/streamspot_parser.py deleted file mode 100644 index 07687828ea5ecf56f9859955e599afb584826f1f..0000000000000000000000000000000000000000 --- a/utils/streamspot_parser.py +++ /dev/null @@ -1,55 +0,0 @@ -import networkx as nx -from tqdm import tqdm -import json -raw_path = '../data/streamspot/' - -NUM_GRAPHS = 600 -node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] -edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', - 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] -node_type_set = set(node_type_dict) -edge_type_set = set(edge_type_dict) -print(raw_path) -count_graph = 0 -with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: - print("Reading") - lines = f.readlines() - print("starting") - g = nx.DiGraph() - node_map = {} - count_node = 0 - for line in tqdm(lines): - src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') - graph_id = int(graph_id) - if src_type not in node_type_set or dst_type not in node_type_set: - continue - if etype not in edge_type_set: - continue - if graph_id != count_graph: - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) - assert graph_id == count_graph - g = nx.DiGraph() - count_node = 0 - if src not in node_map: - node_map[src] = count_node - g.add_node(count_node, type=src_type) - count_node += 1 - if dst not in node_map: - node_map[dst] = count_node - g.add_node(count_node, type=dst_type) - count_node += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=etype) - count_graph += 1 - for n in g.nodes(): - g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) - for e in g.edges(): - g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) - f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') - json.dump(nx.node_link_data(g), f1) diff --git a/utils/trace_parser.py b/utils/trace_parser.py deleted file mode 100644 index 76a91d424a3eea0bbec87ffb635a0d98b27d72f7..0000000000000000000000000000000000000000 --- a/utils/trace_parser.py +++ /dev/null @@ -1,288 +0,0 @@ -import argparse -import json -import os -import random -import re - -from tqdm 
import tqdm -import networkx as nx -import pickle as pkl - - -node_type_dict = {} -edge_type_dict = {} -node_type_cnt = 0 -edge_type_cnt = 0 - -metadata = { - 'trace':{ - 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], - 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] - }, - 'theia':{ - 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], - 'test': ['ta1-theia-e3-official-6r.json.8'] - }, - 'cadets':{ - 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], - 'test': ['ta1-cadets-e3-official-2.json'] - } -} - - -pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') -pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') -pattern_type = re.compile(r'type\":\"(.*?)\"') -pattern_time = re.compile(r'timestampNanos\":(.*?),') -pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') -pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') -pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') - -adversarial = 'None' - - -def read_single_graph(dataset, malicious, path, test=False): - global node_type_cnt, edge_type_cnt - g = nx.DiGraph() - print('converting {} ...'.format(path)) - path = '../data/{}/'.format(dataset) + path + '.txt' - f = open(path, 'r') - lines = [] - for l in f.readlines(): - split_line = l.split('\t') - src, src_type, dst, dst_type, edge_type, ts = split_line - ts = int(ts) - if not test: - if src in malicious or dst in malicious: - if src in malicious and src_type != 'MemoryObject': - continue - if dst in malicious and dst_type != 'MemoryObject': - continue - - if src_type not in node_type_dict: - node_type_dict[src_type] = node_type_cnt - node_type_cnt += 1 - if dst_type not in node_type_dict: - node_type_dict[dst_type] = node_type_cnt - node_type_cnt += 1 - if edge_type not in edge_type_dict: - edge_type_dict[edge_type] = edge_type_cnt - edge_type_cnt += 1 - if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: - lines.append([dst, src, dst_type, src_type, edge_type, ts]) - else: - lines.append([src, dst, src_type, dst_type, edge_type, ts]) - lines.sort(key=lambda l: l[5]) - - node_map = {} - node_type_map = {} - node_cnt = 0 - node_list = [] - for l in lines: - src, dst, src_type, dst_type, edge_type = l[:5] - src_type_id = node_type_dict[src_type] - dst_type_id = node_type_dict[dst_type] - edge_type_id = edge_type_dict[edge_type] - if src not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if src in malicious: - src_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if src not in malicious: - i = random.randint(1, 20) - if i == 1: - src_type_id = node_type_dict['NetFlowObject'] - node_map[src] = node_cnt - g.add_node(node_cnt, type=src_type_id) - node_list.append(src) - node_type_map[src] = src_type - node_cnt += 1 - if dst not in node_map: - if test: - if adversarial == 'MFE' or adversarial == 'MCE': - if dst in malicious: - 
dst_type_id = node_type_dict['FILE_OBJECT_FILE'] - if not test: - if adversarial == 'BFP': - if dst not in malicious: - i = random.randint(1, 20) - if i == 1: - dst_type_id = node_type_dict['NetFlowObject'] - node_map[dst] = node_cnt - g.add_node(node_cnt, type=dst_type_id) - node_type_map[dst] = dst_type - node_list.append(dst) - node_cnt += 1 - if not g.has_edge(node_map[src], node_map[dst]): - g.add_edge(node_map[src], node_map[dst], type=edge_type_id) - if (adversarial == 'MSE' or adversarial == 'MCE') and test: - for i, node in enumerate(node_list): - if node in malicious: - while True: - another_node = random.choice(node_list) - if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) - print(node, another_node) - break - if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) - print(node, another_node) - break - if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': - g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) - print(node, another_node) - break - if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': - break - return node_map, g - - -def preprocess_dataset(dataset): - id_nodetype_map = {} - id_nodename_map = {} - for file in os.listdir('../data/{}/'.format(dataset)): - if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: - print('reading {} ...'.format(file)) - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue - if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue - if len(pattern_uuid.findall(line)) == 0: print(line) - uuid = pattern_uuid.findall(line)[0] - subject_type = pattern_type.findall(line) - - if len(subject_type) < 1: - if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: - subject_type = 'MemoryObject' - if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: - subject_type = 'NetFlowObject' - if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: - subject_type = 'UnnamedPipeObject' - else: - subject_type = subject_type[0] - - if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: - continue - id_nodetype_map[uuid] = subject_type - if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_file_name.findall(line)[0] - elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_process_name.findall(line)[0] - elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: - id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] - for key in metadata[dataset]: - for file in metadata[dataset][key]: - if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): - continue - f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') - fw = 
open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') - print('processing {} ...'.format(file)) - for line in tqdm(f): - if 'com.bbn.tc.schema.avro.cdm18.Event' in line: - edgeType = pattern_type.findall(line)[0] - timestamp = pattern_time.findall(line)[0] - srcId = pattern_src.findall(line) - - if len(srcId) == 0: continue - srcId = srcId[0] - if not srcId in id_nodetype_map: - continue - srcType = id_nodetype_map[srcId] - dstId1 = pattern_dst1.findall(line) - if len(dstId1) > 0 and dstId1[0] != 'null': - dstId1 = dstId1[0] - if not dstId1 in id_nodetype_map: - continue - dstType1 = id_nodetype_map[dstId1] - this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( - dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge1) - - dstId2 = pattern_dst2.findall(line) - if len(dstId2) > 0 and dstId2[0] != 'null': - dstId2 = dstId2[0] - if not dstId2 in id_nodetype_map.keys(): - continue - dstType2 = id_nodetype_map[dstId2] - this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( - dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' - fw.write(this_edge2) - fw.close() - f.close() - if len(id_nodename_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') - json.dump(id_nodename_map, fw) - if len(id_nodetype_map) != 0: - fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') - json.dump(id_nodetype_map, fw) - - -def read_graphs(dataset): - malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) - f = open(malicious_entities, 'r') - malicious_entities = set() - for l in f.readlines(): - malicious_entities.add(l.lstrip().rstrip()) - - preprocess_dataset(dataset) - train_gs = [] - for file in metadata[dataset]['train']: - _, train_g = read_single_graph(dataset, malicious_entities, file, False) - train_gs.append(train_g) - test_gs = [] - test_node_map = {} - count_node = 0 - for file in metadata[dataset]['test']: - node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) - assert len(node_map) == test_g.number_of_nodes() - test_gs.append(test_g) - for key in node_map: - if key not in test_node_map: - test_node_map[key] = node_map[key] + count_node - count_node += test_g.number_of_nodes() - - if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): - with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodename_map = json.load(f) - with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: - id_nodetype_map = json.load(f) - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': - final_malicious_entities.append(test_node_map[e]) - if e in id_nodename_map: - malicious_names.append(id_nodename_map[e]) - f.write('{}\t{}\n'.format(e, id_nodename_map[e])) - else: - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - else: - f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') - final_malicious_entities = [] - malicious_names = [] - for e in malicious_entities: - if e in test_node_map: - final_malicious_entities.append(test_node_map[e]) - malicious_names.append(e) - f.write('{}\t{}\n'.format(e, e)) - - 
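# (Editor's annotation, not part of the original file.) The three pickles
# written below are what the entity-level loader later reads back:
# malicious.pkl holds the pair (malicious node indices, malicious names),
# while train.pkl and test.pkl hold the training and test graphs serialized
# with nx.node_link_data().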
pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) - pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CDM Parser') - parser.add_argument("--dataset", type=str, default="trace") - args = parser.parse_args() - if args.dataset not in ['trace', 'theia', 'cadets']: - raise NotImplementedError - read_graphs(args.dataset) - diff --git a/utils/wget_parser.py b/utils/wget_parser.py deleted file mode 100644 index 3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055..0000000000000000000000000000000000000000 --- a/utils/wget_parser.py +++ /dev/null @@ -1,820 +0,0 @@ -import argparse -import json -import os - -import xxhash -import tqdm -import logging -import networkx as nx -from tqdm import tqdm -import time -import datetime - - -valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] -CONSOLE_ARGUMENTS = None - - -def hashgen(l): - """Generate a single hash value from a list. @l is a list of - string values, which can be properties of a node/edge. This - function returns a single hashed integer value.""" - hasher = xxhash.xxh64() - for e in l: - hasher.update(e) - return hasher.intdigest() - - -def parse_nodes(json_string, node_map): - """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). - Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, - which is assigned by CamFlow to uniquely identify a node object, to a hashed - value (in str) which represents the 'type' of the node. """ - json_object = None - try: - # use "ignore" if non-decodeable exists in the @json_string - json_object = json.loads(json_string) - except Exception as e: - print("Exception ({}) occurred when parsing a node in JSON:".format(e)) - print(json_string) - exit(1) - if "activity" in json_object: - activity = json_object["activity"] - for uid in activity: - if not uid in node_map: # only parse unseen nodes - if "prov:type" not in activity[uid]: - # a node must have a type. - # record this issue if logging is turned on - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = activity[uid]["prov:type"] - - if "entity" in json_object: - entity = json_object["entity"] - for uid in entity: - if not uid in node_map: - if "prov:type" not in entity[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) - else: - node_map[uid] = entity[uid]["prov:type"] - - -def parse_all_nodes(filename, node_map): - """Parse all nodes in CamFlow data. @filename is the file path of - the CamFlow data to parse. @node_map contains the mappings of all - CamFlow nodes to their hashed attributes. """ - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(filename, 'r') as f: - # each line in CamFlow data could contain multiple - # provenance nodes, we call @parse_nodes routine. 
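# (Editor's annotation, not part of the original file.) A minimal,
# hypothetical CamFlow record of the shape parse_nodes() expects; real
# records carry many more fields:
#   {"activity": {"<uuid-1>": {"prov:type": "task"}},
#    "entity":   {"<uuid-2>": {"prov:type": "file"}}}
# Only the top-level "activity"/"entity" keys and each node's "prov:type"
# are consulted when populating node_map.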
- for line in f: - pb.update() # for progress tracking - parse_nodes(line, node_map) - f.close() - pb.close() - - -def parse_all_edges(inputfile, outputfile, node_map, noencode): - """Parse all edges (including their timestamp) from CamFlow data file @inputfile - to an @outputfile. Before this function is called, parse_all_nodes should be called - to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, - we do not hash the nodes' original UUIDs generated by CamFlow to integers. This - function returns the total number of valid edge parsed from CamFlow dataset. - - The output edgelist has the following format for each line, if -s is not set: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> - If -s is set, each line would look like: - <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" - total_edges = 0 - smallest_timestamp = None - # scan through the entire file to find the smallest timestamp from all the edges. - # this step is only needed if we need to add some statistical information. - if CONSOLE_ARGUMENTS.stats: - description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - continue - if "cf:date" not in used[uid]: - continue - if "prov:entity" not in used[uid]: - continue - if "prov:activity" not in used[uid]: - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - continue - if "cf:date" not in wasGeneratedBy[uid]: - continue - if "prov:entity" not in wasGeneratedBy[uid]: - continue - if "prov:activity" not in wasGeneratedBy[uid]: - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - continue - if "cf:date" not in wasInformedBy[uid]: - continue - if "prov:informant" not in wasInformedBy[uid]: - continue - if "prov:informed" not in wasInformedBy[uid]: - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, 
"%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - continue - if "cf:date" not in wasDerivedFrom[uid]: - continue - if "prov:usedEntity" not in wasDerivedFrom[uid]: - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - continue - if "cf:date" not in wasAssociatedWith[uid]: - continue - if "prov:agent" not in wasAssociatedWith[uid]: - continue - if "prov:activity" not in wasAssociatedWith[uid]: - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - continue - if dstUUID not in node_map: - continue - timestamp_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - if smallest_timestamp == None or ts < smallest_timestamp: - smallest_timestamp = ts - f.close() - pb.close() - - # we will go through the CamFlow data (again) and output edgelist to a file - output = open(outputfile, "w+") - description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) - pb = tqdm(desc=description, mininterval=1.0, unit=" recs") - with open(inputfile, 'r') as f: - for line in f: - pb.update() - json_object = json.loads(line) - - if "used" in json_object: - used = json_object["used"] - for uid in used: - if "prov:type" not in used[uid]: - # an edge must have a type; if not, - # we will have to skip the edge. Log - # this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without type: {}".format(uid)) - continue - else: - edgetype = "used" - # cf:id is used as logical timestamp to order edges - if "cf:id" not in used[uid]: - # an edge must have a logical timestamp; - # if not we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = used[uid]["cf:id"] - if "prov:entity" not in used[uid]: - # an edge's source node must exist; - # if not, we will have to skip the - # edge. Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) - continue - if "prov:activity" not in used[uid]: - # an edge's destination node must exist; - # if not, we will have to skip the edge. - # Log this issue if verbose is set. 
- if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], - uid)) - continue - srcUUID = used[uid]["prov:entity"] - dstUUID = used[uid]["prov:activity"] - # both source and destination node must - # exist in @node_map; if not, we will - # have to skip the edge. Log this issue - # if verbose is set. - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug( - "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in used[uid]: - # an edge must have a timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without timestamp: {}".format(uid)) - continue - else: - # we only record @adjusted_ts if we need - # to record stats of CamFlow dataset. - if CONSOLE_ARGUMENTS.stats: - ts_str = used[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in used[uid]: - # an edge must have a jiffies timestamp; if - # not, we will have to skip the edge. - # Log this issue if verbose is set. - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (used) record without jiffies: {}".format(uid)) - continue - else: - # we only record @jiffies if - # the option is set - if CONSOLE_ARGUMENTS.jiffies: - jiffies = used[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) - - if "wasGeneratedBy" in json_object: - wasGeneratedBy = json_object["wasGeneratedBy"] - for uid in wasGeneratedBy: - if "prov:type" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasGeneratedBy" - if "cf:id" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasGeneratedBy[uid]["cf:id"] - if "prov:entity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - 
logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasGeneratedBy[uid]["prov:activity"] - dstUUID = wasGeneratedBy[uid]["prov:entity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( - wasGeneratedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasGeneratedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasGeneratedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasGeneratedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasInformedBy" in json_object: - wasInformedBy = json_object["wasInformedBy"] - for uid in wasInformedBy: - if "prov:type" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) - continue - else: - edgetype = "wasInformedBy" - if "cf:id" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasInformedBy[uid]["cf:id"] - if "prov:informant" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - if "prov:informed" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - srcUUID = wasInformedBy[uid]["prov:informant"] - dstUUID = wasInformedBy[uid]["prov:informed"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - 
else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( - wasInformedBy[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasInformedBy[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasInformedBy[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasInformedBy[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasDerivedFrom" in json_object: - wasDerivedFrom = json_object["wasDerivedFrom"] - for uid in wasDerivedFrom: - if "prov:type" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) - continue - else: - edgetype = "wasDerivedFrom" - if "cf:id" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasDerivedFrom[uid]["cf:id"] - if "prov:usedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - if "prov:generatedEntity" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] - dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( - wasDerivedFrom[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without 
timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasDerivedFrom[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasDerivedFrom[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasDerivedFrom[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - - if "wasAssociatedWith" in json_object: - wasAssociatedWith = json_object["wasAssociatedWith"] - for uid in wasAssociatedWith: - if "prov:type" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) - continue - else: - edgetype = "wasAssociatedWith" - if "cf:id" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) - continue - else: - timestamp = wasAssociatedWith[uid]["cf:id"] - if "prov:agent" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - if "prov:activity" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - srcUUID = wasAssociatedWith[uid]["prov:agent"] - dstUUID = wasAssociatedWith[uid]["prov:activity"] - if srcUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - srcVal = node_map[srcUUID] - if dstUUID not in node_map: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( - wasAssociatedWith[uid]["prov:type"], uid)) - continue - else: - dstVal = node_map[dstUUID] - if "cf:date" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.stats: - ts_str = wasAssociatedWith[uid]["cf:date"] - ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) - adjusted_ts = ts - smallest_timestamp - if "cf:jiffies" not in wasAssociatedWith[uid]: - if CONSOLE_ARGUMENTS.verbose: - 
logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) - continue - else: - if CONSOLE_ARGUMENTS.jiffies: - jiffies = wasAssociatedWith[uid]["cf:jiffies"] - total_edges += 1 - if noencode: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) - else: - if CONSOLE_ARGUMENTS.stats: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - adjusted_ts)) - elif CONSOLE_ARGUMENTS.jiffies: - output.write( - "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, - dstVal, edgetype, timestamp, - jiffies)) - else: - output.write( - "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, - edgetype, timestamp)) - f.close() - output.close() - pb.close() - return total_edges - -def read_single_graph(file_name, threshold): - graph = [] - edge_cnt = 0 - with open(file_name, 'r') as f: - for line in f: - try: - edge = line.strip().split("\t") - new_edge = [edge[0], edge[1]] - attributes = edge[2].strip().split(":") - source_node_type = attributes[0] - destination_node_type = attributes[1] - edge_type = attributes[2] - edge_order = attributes[3] - - new_edge.append(source_node_type) - new_edge.append(destination_node_type) - new_edge.append(edge_type) - new_edge.append(edge_order) - graph.append(new_edge) - edge_cnt += 1 - except: - print("{}".format(line)) - f.close() - graph.sort(key=lambda e: e[5]) - if len(graph) <= threshold: - return graph - else: - return graph[:threshold] - - -def process_graph(name, threshold): - graph = read_single_graph(name, threshold) - result_graph = nx.DiGraph() - cnt = 0 - for num, edge in enumerate(graph): - cnt += 1 - src, dst, src_type, dst_type, edge_type = edge[:5] - if True:# src_type in valid_node_type and dst_type in valid_node_type: - if not result_graph.has_node(src): - result_graph.add_node(src, type=src_type) - if not result_graph.has_node(dst): - result_graph.add_node(dst, type=dst_type) - if not result_graph.has_edge(src, dst): - result_graph.add_edge(src, dst, type=edge_type) - if bidirection: - result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) - return cnt, result_graph - - -node_type_list = [] -edge_type_list = [] -node_type_dict = {} -edge_type_dict = {} - - -def format_graph(g, name): - new_g = nx.DiGraph() - node_map = {} - node_cnt = 0 - for n in g.nodes: - node_map[n] = node_cnt - new_g.add_node(node_cnt, type=g.nodes[n]['type']) - node_cnt += 1 - for e in g.edges: - new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) - for n in new_g.nodes: - node_type = new_g.nodes[n]['type'] - if not node_type in node_type_dict: - node_type_list.append(node_type) - node_type_dict[node_type] = 1 - else: - node_type_dict[node_type] += 1 - for e in new_g.edges: - edge_type = new_g.edges[e]['type'] - if not edge_type in edge_type_dict: - edge_type_list.append(edge_type) - edge_type_dict[edge_type] = 1 - else: - edge_type_dict[edge_type] += 1 - for n in new_g.nodes: - new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) - for e in new_g.edges: - new_g.edges[e]['type'] = 
edge_type_list.index(new_g.edges[e]['type'])
-    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
-        json.dump(nx.node_link_data(new_g), f)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
-    args = parser.parse_args()
-    args.stats = False
-    args.verbose = False
-    args.jiffies = False
-    args.input = '../data/wget/raw/'
-    args.output = '../data/wget/processed/'
-    args.final_output = '../data/wget/final/'
-    args.noencode = False
-    if not os.path.exists(args.input):
-        os.mkdir(args.input)
-    if not os.path.exists(args.output):
-        os.mkdir(args.output)
-    if not os.path.exists(args.final_output):
-        os.mkdir(args.final_output)
-    CONSOLE_ARGUMENTS = args
-
-    if args.verbose:
-        logging.basicConfig(filename=args.log, level=logging.DEBUG)
-    cnt = 0
-    for fname in os.listdir(args.input):
-        cnt += 1
-        node_map = dict()
-        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
-        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
-                                      args.noencode)
-        if args.stats:
-            total_nodes = len(node_map)
-            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
-            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
-
-    bidirection = False
-    threshold = 10000000 # infinity
-    interaction_dict = []
-    graph_cnt = 0
-    result_graphs = []
-    input = args.output
-    base = args.final_output
-
-    line_cnt = 0
-    for i in tqdm(range(cnt)):
-        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
-        format_graph(result_graph, '{}{}'.format(base, i))
-        line_cnt += single_cnt
-
-    print(line_cnt // 150)
-    print(len(node_type_list))
-    print(node_type_dict)
-    print(len(edge_type_list))
-    print(edge_type_dict)
-
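
For reference, each line that `parse_all_edges` writes follows the layout described in its docstring: `<source_node_id>\t<destination_node_id>\t<src_type>:<dst_type>:<edge_type>:<logical_timestamp>`, with a fifth colon-separated field appended when the stats or jiffies option is enabled. A minimal reader sketch for that layout (the helper and its field names are illustrative, not part of the script; `read_single_graph` above performs the same split):

```
# Sketch only: parse one edgelist line emitted by parse_all_edges.
# Values are kept as strings, mirroring read_single_graph.
def parse_edgelist_line(line):
    src, dst, attrs = line.rstrip("\n").split("\t")
    fields = attrs.split(":")
    edge = {
        "src": src,              # node id (hashed UUID unless noencode is set)
        "dst": dst,
        "src_type": fields[0],   # hashed source node type
        "dst_type": fields[1],   # hashed destination node type
        "edge_type": fields[2],  # relation name, e.g. "used"
        "logical_ts": fields[3], # cf:id, used to order edges
    }
    if len(fields) > 4:          # adjusted timestamp (stats) or jiffies
        edge["extra"] = fields[4]
    return edge

# parse_edgelist_line("12\t34\t5:7:used:42\n")
# -> {'src': '12', 'dst': '34', 'src_type': '5', 'dst_type': '7',
#     'edge_type': 'used', 'logical_ts': '42'}
```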
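The five relation blocks in `parse_all_edges` (`used`, `wasGeneratedBy`, `wasInformedBy`, `wasDerivedFrom`, `wasAssociatedWith`) differ only in which JSON keys hold the source and destination UUIDs. Below is a hedged refactoring sketch, not part of the original script, that captures those pairs as data; a consolidated parser could loop over this table instead of repeating the block per relation (the real code additionally checks `cf:id` and, when enabled, `cf:jiffies`):

```
# Hypothetical helper: relation name -> (source UUID key, destination UUID key),
# mirroring the hard-coded pairs in parse_all_edges.
RELATIONS = {
    "used":              ("prov:entity",     "prov:activity"),
    "wasGeneratedBy":    ("prov:activity",   "prov:entity"),
    "wasInformedBy":     ("prov:informant",  "prov:informed"),
    "wasDerivedFrom":    ("prov:usedEntity", "prov:generatedEntity"),
    "wasAssociatedWith": ("prov:agent",      "prov:activity"),
}

def iter_relation_edges(json_object, node_map):
    """Yield (relation, src_uuid, dst_uuid, record) for every record whose
    type, date and both endpoint keys are present and whose endpoints are
    already known in node_map."""
    for relation, (src_key, dst_key) in RELATIONS.items():
        for uid, record in json_object.get(relation, {}).items():
            if not all(k in record for k in ("prov:type", "cf:date", src_key, dst_key)):
                continue
            src_uuid, dst_uuid = record[src_key], record[dst_key]
            if src_uuid in node_map and dst_uuid in node_map:
                yield relation, src_uuid, dst_uuid, record
```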
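One caveat in `read_single_graph`: the edge order field (`cf:id`) is kept as a string, so `graph.sort(key=lambda e: e[5])` orders edges lexicographically ('10' sorts before '2'). If the truncation to `threshold` is meant to keep the chronologically earliest edges, a numeric key is the usual fix; a small sketch, assuming `cf:id` is numeric as CamFlow emits it:

```
# Hedged tweak, not in the original script: order edges numerically by their
# logical timestamp before truncating. Entries are
# [src, dst, src_type, dst_type, edge_type, edge_order].
def sort_edges_numerically(graph):
    return sorted(graph, key=lambda e: int(e[5]))
```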
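`format_graph` writes each graph as NetworkX node-link JSON, with node and edge `type` attributes re-encoded as integer indices into the global `node_type_list` and `edge_type_list`. A short loader sketch for downstream use (the path is only an example of what the `__main__` block produces):

```
import json
import networkx as nx

# Hypothetical loader for the node-link JSON files written by format_graph.
def load_final_graph(path):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # The dumped data records directed=True, so a DiGraph is reconstructed.
    return nx.node_link_graph(data)

# g = load_final_graph("../data/wget/final/0.json")
# print(g.number_of_nodes(), g.number_of_edges())
```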