diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 22b75bfacc1affcdc8d9e94144606c1fc1eadc6e..f40eb3e6a26ba58cdbed986f27f926b54c24ec37 --- a/README.md +++ b/README.md @@ -1,15 +1,16 @@ -# CONTINUUM-FEDHE-Graph +# FEDHE-Graph -Welcome to the official repository housing the FEDHE-Graph implementation for our solution Continuum! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Continuum, a comprehensive framework for federated learning research. +Welcome to the official repository housing the FEDHE-Graph implementation for Magic! This repository provides you with the necessary tools and resources to leverage federated learning techniques within the context of Magic, a comprehensive framework for federated learning research. - +Original project: https://github.com/FDUDSDE/MAGIC ## Environment Setup -The command are used in an environnement that consist of Windows 11 with anaconda installed +The commands are used in an environment consisting of Ubuntu 22.04 with miniconda installed First, create the conda environment for fedml with MPI support @@ -21,13 +22,13 @@ conda install -c conda-forge mpi4py openmpi pip install "fedml[MPI]" ``` -Clone the Continuum FedML project onto your current folder +Clone the MAGIC FedML project into your current folder ``` -git clone https://github.com/kamelferrahi/[MAGIC_FEDERATED_FedML](https://github.com/kamelferrahi/Continuum_FL) +git clone https://github.com/kamelferrahi/MAGIC_FEDERATED_FedML ``` -Install the necessary packages for Continuum to run +Install the necessary packages for Magic to run ``` conda install -c conda-forge aiohttp=3.9.1 aiosignal=1.3.1 anyio=4.2.0 attrdict=2.0.1 attrs=23.2.0 blis=0.7.11 boto3=1.34.12 botocore=1.34.12 brotli=1.1.0 catalogue=2.0.10 certifi=2023.11.17 chardet=5.2.0 charset-normalizer=3.3.2 click=8.1.7 cloudpathlib=0.16.0 confection=0.1.4 contourpy=1.2.0 cycler=0.12.1 cymem=2.0.8 dgl=1.1.3 dill=0.3.7 fastapi=0.92.0 fedml=0.8.13.post2 filelock=3.13.1 fonttools=4.47.0 frozenlist=1.4.1 fsspec=2023.12.2 gensim=4.3.2 gevent=23.9.1 geventhttpclient=2.0.9 gitdb=4.0.11 GitPython=3.1.40 GPUtil=1.4.0 graphviz=0.8.4 greenlet=3.0.3 h11=0.14.0 h5py=3.10.0 httpcore=1.0.2 httpx=0.26.0 idna=3.6 Jinja2=3.1.2 jmespath=1.0.1 joblib=1.3.2 kiwisolver=1.4.5 langcodes=3.3.0 MarkupSafe=2.1.3 matplotlib=3.8.2 mpi4py=3.1.3 mpmath=1.3.0 multidict=6.0.4 multiprocess=0.70.15 murmurhash=1.0.10 networkx=2.8.8 ntplib=0.4.0 numpy=1.26.3 nvidia-cublas-cu12=12.1.3.1 nvidia-cuda-cupti-cu12=12.1.105 nvidia-cuda-nvrtc-cu12=12.1.105 nvidia-cuda-runtime-cu12=12.1.105 nvidia-cudnn-cu12=8.9.2.26 nvidia-cufft-cu12=11.0.2.54 nvidia-curand-cu12=10.3.2.106 nvidia-cusolver-cu12=11.4.5.107 nvidia-cusparse-cu12=12.1.0.106 nvidia-nccl-cu12=2.18.1 nvidia-nvjitlink-cu12=12.3.101 nvidia-nvtx-cu12=12.1.105 onnx=1.15.0 packaging=23.2 paho-mqtt=1.6.1 pandas=2.1.4 pathtools=0.1.2 pillow=10.2.0 preshed=3.0.9 prettytable=3.9.0 promise=2.3 protobuf=3.20.3 psutil=5.9.7 py-machineid=0.4.6 pydantic=1.10.13 pyparsing=3.1.1 python-dateutil=2.8.2 python-rapidjson=1.14 pytz=2023.3.post1 PyYAML=6.0.1 redis=5.0.1 requests=2.31.0 s3transfer=0.10.0 scikit-learn=1.3.2 scipy=1.11.4 sentry-sdk=1.39.1 setproctitle=1.3.3 shortuuid=1.0.11 six=1.16.0 smart-open=6.3.0 smmap=5.0.1 sniffio=1.3.0 spacy=3.7.2 spacy-legacy=3.0.12 spacy-loggers=1.0.5 SQLAlchemy=2.0.25 srsly=2.4.8 starlette=0.25.0 sympy=1.12 thinc=8.2.2 threadpoolctl=3.2.0 torch=2.1.2
torch-cluster=1.6.3 torch-scatter=2.1.2 torch-sparse=0.6.18 torch-spline-conv=1.2.2 torch_geometric=2.4.0 torchvision=0.16.2 tqdm=4.66.1 triton=2.1.0 tritonclient=2.41.0 typer=0.9.0 typing_extensions=4.9.0 tzdata=2023.4 tzlocal=5.2 urllib3=2.0.7 uvicorn=0.25.0 wandb=0.13.2 wasabi=1.1.2 wcwidth=0.2.12 weasel=0.3.4 websocket-client=1.7.0 wget=3.2 yarl=1.9.4 zope.event=5.0 zope.interface=6.1 @@ -54,7 +55,7 @@ train_args: The algorithm tested are `FedAvg`, `FedProx` and `FedOpt` ## Datasets -The experiments utilize datasets similar to those in the original Continuum project. To change datasets, edit the `fedml_config.yaml` file: +The experiments utilize datasets similar to those in the original Magic project. To change datasets, edit the `fedml_config.yaml` file: ``` data_args: dataset: "wget" diff --git a/a.yaml b/a.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9f8ce6e7fa889de842853b12c3a723d56ab60899 --- /dev/null +++ b/a.yaml @@ -0,0 +1,83 @@ +common_args: + training_type: "cross_silo" + scenario: "horizontal" + using_mlops: false + config_version: release + name: "exp" + project: "runs/train" + exist_ok: false + random_seed: 0 + + common_args: + training_type: "simulation" + random_seed: 0 +comm_args: + backend: "MPI" + is_mobile: 0 + + +data_args: + dataset: "clintox" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "graphsage" + hidden_size: 32 + node_embedding_dim: 32 + graph_embedding_dim: 64 + readout_hidden_dim: 64 + alpha: 0.2 + num_heads: 2 + dropout: 0.3 + normalize_features: False + normalize_adjacency: False + sparse_adjacency: False + model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically + global_model_file_path: "./model_file_cache/global_model.pt" + +environment_args: + bootstrap: config/bootstrap.sh + +train_args: + federated_optimizer: "FedAvg" + client_id_list: + client_num_in_total: 3 + client_num_per_round: 3 + comm_round: 100 + epochs: 5 + batch_size: 64 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + metric: "prc-auc" + server_optimizer: sgd + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 + +validation_args: + frequency_of_the_test: 1 + +device_args: + worker_num: 3 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp + +comm_args: + backend: "MQTT_S3" + mqtt_config_path: config/mqtt_config.yaml + s3_config_path: config/s3_config.yaml + + +tracking_args: + # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ + enable_wandb: false + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + wandb_name: fedml_torch_moleculenet diff --git a/a.yml b/a.yml new file mode 100755 index 0000000000000000000000000000000000000000..adfbd8733fcee48a46e87a24cc48cbb1a0241f87 --- /dev/null +++ b/a.yml @@ -0,0 +1,68 @@ +common_args: + training_type: "simulation" + random_seed: 0 + +data_args: + dataset: "clintox" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition + partition_method: "hetero" + partition_alpha: 0.5 + +model_args: + model: "graphsage" + hidden_size: 32 + node_embedding_dim: 32 + graph_embedding_dim: 64 + readout_hidden_dim: 64 + alpha: 0.2 + num_heads: 2 + dropout: 0.3 + normalize_features: False + normalize_adjacency: False + sparse_adjacency: False + 
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically + global_model_file_path: "./model_file_cache/global_model.pt" + +environment_args: + bootstrap: config/bootstrap.sh + +train_args: + federated_optimizer: "FedAvg" + client_id_list: + client_num_in_total: 3 + client_num_per_round: 3 + comm_round: 100 + epochs: 5 + batch_size: 64 + client_optimizer: sgd + learning_rate: 0.03 + weight_decay: 0.001 + metric: "prc-auc" + server_optimizer: sgd + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 + +validation_args: + frequency_of_the_test: 1 + +device_args: + worker_num: 2 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp + +comm_args: + backend: "MPI" + is_mobile: 0 + + +tracking_args: + # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ + enable_wandb: false + wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 + wandb_project: fedml + wandb_name: fedml_torch_moleculenet diff --git a/checkpoints - Copie/checkpoint-SC2.pt b/checkpoints - Copie/checkpoint-SC2.pt deleted file mode 100644 index 963a001f792613ee2fb2f5f22cb25b599a8dedbb..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-SC2.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt b/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt deleted file mode 100644 index 20073e1503d9f57b75c278bdb89c757d3aaccb30..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-Unicorn-Cadets.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-cadets-e3.pt b/checkpoints - Copie/checkpoint-cadets-e3.pt deleted file mode 100644 index ea7ee25689066513d46aea55e07f71df859ec8c7..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-cadets-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-clearscope-e3.pt b/checkpoints - Copie/checkpoint-clearscope-e3.pt deleted file mode 100644 index 5edaf9d73bdde92067b7e2f6f0129a4673feb945..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-clearscope-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-streamspot.pt b/checkpoints - Copie/checkpoint-streamspot.pt deleted file mode 100644 index ec1ed84e3dc16589af58039e87a34f21cc579cd5..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-streamspot.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-theia-e3.pt b/checkpoints - Copie/checkpoint-theia-e3.pt deleted file mode 100644 index 1513228c205f790835edb723dffe967eec960732..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-theia-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-trace-e3.pt b/checkpoints - Copie/checkpoint-trace-e3.pt deleted file mode 100644 index 755598d71f01bee477182c9fe454b4646ac7ffcd..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-trace-e3.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-wget-long.pt b/checkpoints - Copie/checkpoint-wget-long.pt deleted file mode 100644 index f5c8f96c5be068033746e3b9271b983c438bee54..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-wget-long.pt and /dev/null differ diff --git a/checkpoints - Copie/checkpoint-wget.pt b/checkpoints - 
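The two example configs above (`a.yaml` / `a.yml`) and the README's Datasets section both come down to editing `data_args` / `train_args` in a FedML YAML file. A minimal sketch, not part of this patch, assuming PyYAML (already pinned in the dependency list) and that the file sits at the repository root; dataset names follow the `data/` folders in this diff and the optimizer values are the ones the README lists:

```python
# Hypothetical helper, not part of this repository: flip the dataset/optimizer
# fields of fedml_config.yaml programmatically instead of editing by hand.
import yaml

with open("fedml_config.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["data_args"]["dataset"] = "streamspot"            # e.g. "wget", "streamspot", "trace", ...
cfg["train_args"]["federated_optimizer"] = "FedProx"  # "FedAvg", "FedProx" or "FedOpt" per the README

with open("fedml_config.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```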
Copie/checkpoint-wget.pt deleted file mode 100644 index 1d5f6c0071b7c0d0e136aa6112ade76c083984fc..0000000000000000000000000000000000000000 Binary files a/checkpoints - Copie/checkpoint-wget.pt and /dev/null differ diff --git a/checkpoints/checkpoint-SC2.pt b/checkpoints/checkpoint-SC2.pt deleted file mode 100644 index aec48a89437ad5ef4c3199b0547a3efc9365115f..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-SC2.pt and /dev/null differ diff --git a/checkpoints/checkpoint-Unicorn-Cadets.pt b/checkpoints/checkpoint-Unicorn-Cadets.pt deleted file mode 100644 index 90d5dcfbb036da54972648625073287bad1b6986..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-Unicorn-Cadets.pt and /dev/null differ diff --git a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt b/checkpoints/checkpoint-cadets-e3 - Copie (2).pt deleted file mode 100644 index 5f62ae9abed4280d950ebf431a30b2f883b386a2..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-cadets-e3 - Copie (2).pt and /dev/null differ diff --git a/checkpoints/checkpoint-cadets-e3.pt b/checkpoints/checkpoint-cadets-e3.pt deleted file mode 100644 index 157f6b404a8228c9da2483564ab892e0fd4f8a88..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-cadets-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-clearscope-e3.pt b/checkpoints/checkpoint-clearscope-e3.pt deleted file mode 100644 index 0ebdb975ec2e7455daed67ab002f1b39a9984c1b..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-clearscope-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-streamspot.pt b/checkpoints/checkpoint-streamspot.pt deleted file mode 100644 index 0065a7b8b52665153af952617ab45cde61fe14eb..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-streamspot.pt and /dev/null differ diff --git a/checkpoints/checkpoint-theia-e3.pt b/checkpoints/checkpoint-theia-e3.pt deleted file mode 100644 index 1f0fcc66e18b3e6d534f99e090e5387807451a64..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-theia-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-trace-e3.pt b/checkpoints/checkpoint-trace-e3.pt deleted file mode 100644 index f050065e56c009605562ca6e980652b6c2a33938..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-trace-e3.pt and /dev/null differ diff --git a/checkpoints/checkpoint-wget.pt b/checkpoints/checkpoint-wget.pt deleted file mode 100644 index 8404279998dc213fba6b6f626d55c37ccb7c8ca9..0000000000000000000000000000000000000000 Binary files a/checkpoints/checkpoint-wget.pt and /dev/null differ diff --git a/data/__pycache__/data_loader.cpython-311.pyc b/data/__pycache__/data_loader.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b17492ecf9ba2a9c984303dae0e410e17b03af83 Binary files /dev/null and b/data/__pycache__/data_loader.cpython-311.pyc differ diff --git a/data/cadets/graphs.zip b/data/cadets/graphs.zip new file mode 100755 index 0000000000000000000000000000000000000000..7b014b2ccfa0594ff4dc0e742a2d8453aafeb26d Binary files /dev/null and b/data/cadets/graphs.zip differ diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..f52db36d726a6d6bd742e77841eb3df28b2de34b --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,332 @@ +import logging +import pickle as pkl +import random + +import torch.utils.data as data + +from fedml.core 
import partition_class_samples_with_dirichlet_distribution +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os +import numpy as np +from utils.loaddata import load_rawdata, load_batch_level_dataset, load_entity_level_dataset, load_metadata + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + pathattack = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalattack' + pathbenin = '/home/kamel/pfe/fedml/FedML-master/python/examples/federate/prebuilt_jobs/fedgraphnn/wget_magic/data/finalbenin' + num_graphs_benin = 125 + num_graphs_attack = 25 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs_benin)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(0) + + for i in tqdm(range(num_graphs_attack)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(1) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + +def darpa_split(name): + device = "cpu" + path = './data/' + name + '/' + metadata = load_metadata(name) + n_train = metadata['n_train'] + train_dataset = [] + train_labels = [] + for i in range(n_train): + g = load_entity_level_dataset(name, 'train', i).to(device) + train_dataset.append(g) + train_labels.append(0) + + return ( + train_dataset, + train_labels, + [], + [], + [], + [] + ) + + +def create_random_split(name): + dataset = load_rawdata(name) + # Random 80/10/10 split as suggested + train_range = (0, int(0.8 * len(dataset))) + val_range = ( + int(0.8 * len(dataset)), + int(0.8 * len(dataset)) + int(0.1 * len(dataset)), + ) + test_range = ( + int(0.8 * len(dataset)) + int(0.1 * len(dataset)), + len(dataset), + ) + + all_idxs = list(range(len(dataset))) + random.shuffle(all_idxs) + + train_dataset = [ + dataset[all_idxs[i]] for i in range(train_range[0], train_range[1]) + ] + + train_labels = [dataset[all_idxs[i]][1] for i in range(train_range[0], train_range[1])] + + val_dataset = [ + dataset[all_idxs[i]] for i in range(val_range[0], val_range[1]) + ] + val_labels = [dataset[all_idxs[i]][1] for i in range(val_range[0], val_range[1])] + + test_dataset = [ + dataset[all_idxs[i]] for i in range(test_range[0], test_range[1]) + ] + test_labels = [dataset[all_idxs[i]][1] for i in range(test_range[0], test_range[1])] + + return ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) + + + +def partition_data_by_sample_size( + args, client_number, name, uniform=True, compact=True +): + if (name == 'wget' or name == 'streamspot'): + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = create_random_split(name) + else: + ( + train_dataset, + train_labels, + val_dataset, + val_labels, + test_dataset, + test_labels, + ) = darpa_split(name) + + num_train_samples = len(train_dataset) + num_val_samples = len(val_dataset) + num_test_samples = len(test_dataset) + + train_idxs = list(range(num_train_samples)) + val_idxs = 
list(range(num_val_samples)) + test_idxs = list(range(num_test_samples)) + + random.shuffle(train_idxs) + random.shuffle(val_idxs) + random.shuffle(test_idxs) + + partition_dicts = [None] * client_number + + if uniform: + clients_idxs_train = np.array_split(train_idxs, client_number) + clients_idxs_val = np.array_split(val_idxs, client_number) + clients_idxs_test = np.array_split(test_idxs, client_number) + else: + clients_idxs_train = create_non_uniform_split( + args, train_idxs, client_number, True + ) + clients_idxs_val = create_non_uniform_split( + args, val_idxs, client_number, False + ) + clients_idxs_test = create_non_uniform_split( + args, test_idxs, client_number, False + ) + + labels_of_all_clients = [] + for client in range(client_number): + client_train_idxs = clients_idxs_train[client] + client_val_idxs = clients_idxs_val[client] + client_test_idxs = clients_idxs_test[client] + + train_dataset_client = [ + train_dataset[idx] for idx in client_train_idxs + ] + train_labels_client = [train_labels[idx] for idx in client_train_idxs] + labels_of_all_clients.append(train_labels_client) + + val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] + val_labels_client = [val_labels[idx] for idx in client_val_idxs] + + test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] + test_labels_client = [test_labels[idx] for idx in client_test_idxs] + + + partition_dict = { + "train": train_dataset_client, + "val": val_dataset_client, + "test": test_dataset_client, + } + + partition_dicts[client] = partition_dict + global_data_dict = { + "train": train_dataset, + "val": val_dataset, + "test": test_dataset, + } + + return global_data_dict, partition_dicts + +def load_partition_data( + args, + client_number, + name, + uniform=True, + global_test=True, + compact=True, + normalize_features=False, + normalize_adj=False, +): + global_data_dict, partition_dicts = partition_data_by_sample_size( + args, client_number, name, uniform, compact=compact + ) + + data_local_num_dict = dict() + train_data_local_dict = dict() + val_data_local_dict = dict() + test_data_local_dict = dict() + + + + # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. 
+ train_data_global = global_data_dict["train"] + val_data_global = global_data_dict["val"] + test_data_global = global_data_dict["test"] + train_data_num = len(global_data_dict["train"]) + val_data_num = len(global_data_dict["val"]) + test_data_num = len(global_data_dict["test"]) + + for client in range(client_number): + train_dataset_client = partition_dicts[client]["train"] + val_dataset_client = partition_dicts[client]["val"] + test_dataset_client = partition_dicts[client]["test"] + + data_local_num_dict[client] = len(train_dataset_client) + train_data_local_dict[client] = train_dataset_client, + + val_data_local_dict[client] = val_dataset_client + + test_data_local_dict[client] = ( + test_data_global + if global_test + else test_dataset_client + + ) + + logging.info( + "Client idx = {}, local sample number = {}".format( + client, len(train_dataset_client) + ) + ) + + return ( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) + + + +def load_batch_level_dataset_main(name): + dataset = get_data(name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +class GraphDataset(dgl.data.DGLDataset): + def __init__(self, graph_label_list): + super(GraphDataset, self).__init__(name="wget") + self.graph_label_list = graph_label_list + + def __len__(self): + return len(self.graph_label_list) + + def __getitem__(self, idx): + graph, label = self.graph_label_list[idx] + # Convert the graph to a DGLGraph to work with DGL + return graph, label + + +def transform_data(data): + dataset = GraphDataset(data[0]) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + diff --git a/data/theia/graphs.zip b/data/theia/graphs.zip new file mode 100755 index 0000000000000000000000000000000000000000..3a98e4af6ec79f7199a8d843ff734bdce4ac59f9 Binary files /dev/null and b/data/theia/graphs.zip differ diff --git a/checkpoints/checkpoint-wget-long.pt b/data/trace/graphs.zip old mode 100644 new mode 100755 similarity index 60% rename from checkpoints/checkpoint-wget-long.pt rename to data/trace/graphs.zip index 
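The new `data/data_loader.py` above is consumed the same way in `main.py` further down. A minimal usage sketch, assuming the `wget` graphs are already available locally as in the original MAGIC setup:

```python
# Sketch of calling the loader added above (mirrors the call in main.py).
# The first argument is only forwarded to the non-uniform split helper,
# so main.py simply passes None for the uniform case.
from data.data_loader import load_partition_data

(
    train_data_num, val_data_num, test_data_num,
    train_data_global, val_data_global, test_data_global,
    data_local_num_dict, train_data_local_dict,
    val_data_local_dict, test_data_local_dict,
) = load_partition_data(None, 4, "wget")   # 4 clients, uniform split

print("global training graphs:", train_data_num)
print("per-client sample counts:", data_local_num_dict)
```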
040897b99810c72f6bb6d2314aab468b3972de3e..98fdf390d57c6cec2ee07134ef7d4fe25accac2f Binary files a/checkpoints/checkpoint-wget-long.pt and b/data/trace/graphs.zip differ diff --git a/distance_save_cadets.pkl b/distance_save_cadets.pkl new file mode 100644 index 0000000000000000000000000000000000000000..1f2f428d6e2e6ffb8afbfd0d0669c46cab8718c6 Binary files /dev/null and b/distance_save_cadets.pkl differ diff --git a/eval_result/distance_save_cadets-e3 - Copie.pkl b/eval_result/distance_save_cadets-e3 - Copie.pkl deleted file mode 100644 index 7f737190e44abb34ebea3e6763c64b9f1ae3a89a..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets-e3 - Copie.pkl and /dev/null differ diff --git a/eval_result/distance_save_cadets-e3.pkl b/eval_result/distance_save_cadets-e3.pkl deleted file mode 100644 index 10c18f8c50bfce89a1afd0811cc24ce8e0595786..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_cadets-e3.pkl and /dev/null differ diff --git a/eval_result/distance_save_cadets.pkl b/eval_result/distance_save_cadets.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2bb4de8e00a64a22e53b1a5187821e6e53dfffcc Binary files /dev/null and b/eval_result/distance_save_cadets.pkl differ diff --git a/eval_result/distance_save_theia-e3.pkl b/eval_result/distance_save_theia-e3.pkl deleted file mode 100644 index c5d31bab1acc6e943c746a154eec3a9370ccb644..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_theia-e3.pkl and /dev/null differ diff --git a/eval_result/distance_save_trace-e3 - Copie.pkl b/eval_result/distance_save_trace-e3 - Copie.pkl deleted file mode 100644 index 709b4f83bc0d8cd4d8be175fed1c54d366c9de16..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_trace-e3 - Copie.pkl and /dev/null differ diff --git a/eval_result/distance_save_trace-e3.pkl b/eval_result/distance_save_trace-e3.pkl deleted file mode 100644 index 709b4f83bc0d8cd4d8be175fed1c54d366c9de16..0000000000000000000000000000000000000000 Binary files a/eval_result/distance_save_trace-e3.pkl and /dev/null differ diff --git a/fedml_config.yaml b/fedml_config.yaml old mode 100644 new mode 100755 index 40cd616a925f3514dda97dd8ba1b20aaead81744..66714a20c7a2b471676dac8be0801b3ae132e5cf --- a/fedml_config.yaml +++ b/fedml_config.yaml @@ -1,53 +1,49 @@ common_args: - training_type: "cross_silo" - scenario: "horizontal" - using_mlops: false - config_version: release - name: "exp" - project: "runs/train" - exist_ok: false + training_type: "simulation" random_seed: 0 data_args: - dataset: "trace-e3" + dataset: "wget" + data_cache_dir: ~/fedgraphnn_data/ + part_file: ~/fedgraphnn_data/partition model_args: model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically global_model_file_path: "./model_file_cache/global_model.pt" +environment_args: + bootstrap: config/bootstrap.sh train_args: federated_optimizer: "FedAvg" client_id_list: - client_num_in_total: 2 - client_num_per_round: 2 - comm_round: 1 - snapshot: 1 + client_num_in_total: 4 + client_num_per_round: 4 + comm_round: 100 + lr: 0.001 + server_lr: 0.001 + wd: 0.001 + ci: 0 + server_momentum: 0.9 validation_args: frequency_of_the_test: 1 device_args: - worker_num: 2 - using_gpu: true - gpu_mapping_file: gpu_mapping.yaml - gpu_mapping_key: mapping_config + worker_num: 4 + using_gpu: false + gpu_mapping_file: config/gpu_mapping.yaml + gpu_mapping_key: mapping_fedgraphnn_sp comm_args: - backend: "MQTT_S3" - mqtt_config_path: 
config/mqtt_config.yaml + backend: "MPI" + is_mobile: 0 + tracking_args: # When running on MLOps platform(open.fedml.ai), the default log path is at ~/.fedml/fedml-client/fedml/logs/ and ~/.fedml/fedml-server/fedml/logs/ enable_wandb: false wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 wandb_project: fedml - wandb_name: fedml_torch - -# fhe_args: -# # enable_fhe: true - # scheme: ckks -# batch_size: 8192 -# scaling_factor: 52 -# file_loc: "resources/cryptoparams/" \ No newline at end of file + wandb_name: fedml_torch_moleculenet diff --git a/gpu_mapping.yaml b/gpu_mapping.yaml deleted file mode 100644 index 90b13bfe6abfa84b56845ac220a5c67ced781cbe..0000000000000000000000000000000000000000 --- a/gpu_mapping.yaml +++ /dev/null @@ -1,2 +0,0 @@ -mapping_config: - host1: [3] \ No newline at end of file diff --git a/main.py b/main.py old mode 100644 new mode 100755 index 7d7430d7fb88d4259d8096266633bb5dbd94bc8f..a1a716e17f69f73d33fa62ef5b3da1e85b90b715 --- a/main.py +++ b/main.py @@ -1,18 +1,20 @@ import logging import fedml -from utils.dataloader import load_partition_data, load_data, load_metadata, darpa_split +from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split from fedml import FedMLRunner from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator -from model.model import STGNN_AutoEncoder +from model.autoencoder import build_model from utils.config import build_args from trainer.magic_trainer import MagicTrainer from trainer.magic_aggregator import MagicWgetAggregator +from trainer.single_trainer import train_single +from utils.loaddata import load_batch_level_dataset, load_metadata -def generate_dataset(name, number, nsnapshot): +def generate_dataset(name, number): ( train_data_num, val_data_num, @@ -24,7 +26,7 @@ def generate_dataset(name, number, nsnapshot): train_data_local_dict, val_data_local_dict, test_data_local_dict, - ) = load_partition_data(number, name, nsnapshot) + ) = load_partition_data(None, number, name) dataset = [ train_data_num, test_data_num, @@ -36,9 +38,9 @@ def generate_dataset(name, number, nsnapshot): len(train_data_global), ] - if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): + if (name == "wget" or name == "streamspot"): - return dataset, load_data(name) + return dataset, load_batch_level_dataset(name) else: return dataset, load_metadata(name) @@ -47,47 +49,39 @@ if __name__ == "__main__": # init FedML framework args = fedml.init() # init device - device = fedml.device.get_device(args) - dataset_name = args.dataset + name = args.dataset number = args.client_num_in_total - nsnapshot = args.snapshot - dataset, metadata = generate_dataset(dataset_name, number, nsnapshot) + + dataset, metadata = generate_dataset(name, number) main_args = build_args() - if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - - main_args.max_epoch = 6 - out_dim = 64 - if (dataset_name == 'SC2'): - gnn_layer = 3 - else: - gnn_layer = 5 + if (name == "wget"): + main_args["num_hidden"] = 256 + main_args["max_epoch"] = 2 + main_args["num_layers"] = 4 n_node_feat = metadata['n_feat'] n_edge_feat = metadata['e_feat'] - use_all_hidden = True - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - else: - use_all_hidden = False - n_node_feat = metadata['node_feature_dim'] - n_edge_feat = metadata['edge_feature_dim'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 
32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - main_args.max_epoch = 50 - out_dim = 64 - - if (dataset_name == 'cadets-e3'): - gnn_layer = 4 - else: - gnn_layer = 3 - - - - model = STGNN_AutoEncoder(main_args.n_dim, main_args.e_dim, out_dim, out_dim, gnn_layer, 4, device, nsnapshot, 'prelu', 0.1, main_args.negative_slope, True, 'BatchNorm', main_args.pooling, alpha_l=main_args.alpha_l, use_all_hidden=use_all_hidden).to(device) # Move model to GPU + main_args["n_dim"] = n_node_feat + main_args["e_dim"] = n_edge_feat + elif (name == "streamspot"): + main_args["num_hidden"] = 256 + main_args["max_epoch"] = 5 + main_args["num_layers"] = 4 + n_node_feat = metadata['n_feat'] + n_edge_feat = metadata['e_feat'] + main_args["n_dim"] = n_node_feat + main_args["e_dim"] = n_edge_feat + else: + main_args["num_hidden"] = 64 + main_args["max_epoch"] = 50 + main_args["num_layers"] = 3 + main_args["n_dim"] = metadata["node_feature_dim"] + main_args["e_dim"] = metadata["edge_feature_dim"] + + model = build_model(main_args) #train_single(main_args, model, data) - trainer = MagicTrainer(model, args, dataset_name) - aggregator = MagicWgetAggregator(model, args, dataset_name) + trainer = MagicTrainer(model, args, name) + aggregator = MagicWgetAggregator(model, args, name) fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator) fedml_runner.run() # start training diff --git a/model/__pycache__/autoencoder.cpython-311.pyc b/model/__pycache__/autoencoder.cpython-311.pyc index 3abc72625245a57fa06470e3659c594a0c033046..13b41e99adb07cc4cfa61dd2771b34854aa88d80 100644 Binary files a/model/__pycache__/autoencoder.cpython-311.pyc and b/model/__pycache__/autoencoder.cpython-311.pyc differ diff --git a/model/__pycache__/eval.cpython-310.pyc b/model/__pycache__/eval.cpython-310.pyc deleted file mode 100644 index 0dc58b1214be9d335558c42907cf1feb18b81969..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/eval.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/eval.cpython-311.pyc b/model/__pycache__/eval.cpython-311.pyc old mode 100644 new mode 100755 index faade8397a5b3c546b35f34aef088d8992486510..c7d0bb5b6623e0d409dc28636e2722f423c0c7f0 Binary files a/model/__pycache__/eval.cpython-311.pyc and b/model/__pycache__/eval.cpython-311.pyc differ diff --git a/model/__pycache__/gat.cpython-310.pyc b/model/__pycache__/gat.cpython-310.pyc deleted file mode 100644 index 54fa4abdd4680faa41fbf0e19697fab2f000402b..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/gat.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/gat.cpython-311.pyc b/model/__pycache__/gat.cpython-311.pyc old mode 100644 new mode 100755 index 94e6046f4f22c6a2f271c2ca3bded5d8f2417cfa..0c9997f570d123655a70df66a86c3003e82f25ab Binary files a/model/__pycache__/gat.cpython-311.pyc and b/model/__pycache__/gat.cpython-311.pyc differ diff --git a/model/__pycache__/loss_func.cpython-310.pyc b/model/__pycache__/loss_func.cpython-310.pyc deleted file mode 100644 index 79ac15810418709637606b15c5c43b4b96aec691..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/loss_func.cpython-310.pyc and /dev/null differ diff --git 
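For context on the rewritten `main.py` above: `fedml.init()` reads the YAML shown earlier (typically passed as `--cf fedml_config.yaml`) and exposes the config sections as flat attributes, which is why `args.dataset` and `args.client_num_in_total` are read directly. With the MPI backend, runs of this kind are normally launched through `mpirun` against the repo's `mpi_host_file`, one process per worker plus one for the server. A small sketch of the attribute access, assuming that launch setup:

```python
# Sketch only: what fedml.init() hands back for the config in this patch.
import fedml

args = fedml.init()                       # parses --cf fedml_config.yaml
print(args.dataset)                       # "wget"   (from data_args)
print(args.client_num_in_total)           # 4        (from train_args)
print(args.federated_optimizer)           # "FedAvg" (from train_args)
device = fedml.device.get_device(args)    # same call main.py makes
```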
a/model/__pycache__/loss_func.cpython-311.pyc b/model/__pycache__/loss_func.cpython-311.pyc old mode 100644 new mode 100755 index 40b3c9d58012a8600cda16f7d39ee5c5891fa862..620a6bee7cb3ab8dc89bbc3448fdac41abb5f5ea Binary files a/model/__pycache__/loss_func.cpython-311.pyc and b/model/__pycache__/loss_func.cpython-311.pyc differ diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc deleted file mode 100644 index 3984a3444fb13f56ea0c2629e2b0883298bf8f4d..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/model.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/model.cpython-311.pyc b/model/__pycache__/model.cpython-311.pyc deleted file mode 100644 index d9a27f13ee6839a64257d3125b777544a866f329..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/model.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/rnn.cpython-310.pyc b/model/__pycache__/rnn.cpython-310.pyc deleted file mode 100644 index d02d9b442912cda476cacbc18673eebe9e000da7..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/rnn.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/rnn.cpython-311.pyc b/model/__pycache__/rnn.cpython-311.pyc deleted file mode 100644 index 9d71283ae5cbfc01ec9eb9f1448d8e8b292d1f08..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/rnn.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/test.cpython-310.pyc b/model/__pycache__/test.cpython-310.pyc deleted file mode 100644 index 1d85be9aa8853ef325090a1b1cd0f366d3224792..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/test.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/test.cpython-311.pyc b/model/__pycache__/test.cpython-311.pyc deleted file mode 100644 index 0d7beba377fb059135e3135b292a910477636c7c..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/test.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/train.cpython-311.pyc b/model/__pycache__/train.cpython-311.pyc old mode 100644 new mode 100755 index b788a354802ea312364f91691db5b00e502922a6..c202e2105beb34aa9db7a6eddeffd4cef4b25fa3 Binary files a/model/__pycache__/train.cpython-311.pyc and b/model/__pycache__/train.cpython-311.pyc differ diff --git a/model/__pycache__/train_entity.cpython-310.pyc b/model/__pycache__/train_entity.cpython-310.pyc deleted file mode 100644 index cbe4cebe4c9d3a09cdc4a8b78f976f98a23baef1..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_entity.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/train_entity.cpython-311.pyc b/model/__pycache__/train_entity.cpython-311.pyc deleted file mode 100644 index cf1b64f3a76e749403864eb4a4a10154a8003a84..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_entity.cpython-311.pyc and /dev/null differ diff --git a/model/__pycache__/train_graph.cpython-310.pyc b/model/__pycache__/train_graph.cpython-310.pyc deleted file mode 100644 index cf931f350c1de50fb725f95c7ebc73f1c6facc66..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_graph.cpython-310.pyc and /dev/null differ diff --git a/model/__pycache__/train_graph.cpython-311.pyc b/model/__pycache__/train_graph.cpython-311.pyc deleted file mode 100644 index d98093beaf11c1ffeb71b9abc556e33503934038..0000000000000000000000000000000000000000 Binary files a/model/__pycache__/train_graph.cpython-311.pyc and 
/dev/null differ diff --git a/model/autoencoder.py b/model/autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..da51f590f02076ad43a0f8f40037ae4ad5b4c761 --- /dev/null +++ b/model/autoencoder.py @@ -0,0 +1,179 @@ +from .gat import GAT +from utils.utils import create_norm +from functools import partial +from itertools import chain +from .loss_func import sce_loss +import torch +import torch.nn as nn +import dgl +import random + + +def build_model(args): + num_hidden = args["num_hidden"] + num_layers = args["num_layers"] + negative_slope = args["negative_slope"] + mask_rate = args["mask_rate"] + alpha_l = args["alpha_l"] + n_dim = args["n_dim"] + e_dim = args["e_dim"] + + model = GMAEModel( + n_dim=n_dim, + e_dim=e_dim, + hidden_dim=num_hidden, + n_layers=num_layers, + n_heads=4, + activation="prelu", + feat_drop=0.1, + negative_slope=negative_slope, + residual=True, + mask_rate=mask_rate, + norm='BatchNorm', + loss_fn='sce', + alpha_l=alpha_l + ) + return model + + +class GMAEModel(nn.Module): + def __init__(self, n_dim, e_dim, hidden_dim, n_layers, n_heads, activation, + feat_drop, negative_slope, residual, norm, mask_rate=0.5, loss_fn="sce", alpha_l=2): + super(GMAEModel, self).__init__() + self._mask_rate = mask_rate + self._output_hidden_size = hidden_dim + self.recon_loss = nn.BCELoss(reduction='mean') + + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform(m.weight) + nn.init.constant_(m.bias, 0) + + self.edge_recon_fc = nn.Sequential( + nn.Linear(hidden_dim * n_layers * 2, hidden_dim), + nn.LeakyReLU(negative_slope), + nn.Linear(hidden_dim, 1), + nn.Sigmoid() + ) + self.edge_recon_fc.apply(init_weights) + + assert hidden_dim % n_heads == 0 + enc_num_hidden = hidden_dim // n_heads + enc_nhead = n_heads + + dec_in_dim = hidden_dim + dec_num_hidden = hidden_dim + + # build encoder + self.encoder = GAT( + n_dim=n_dim, + e_dim=e_dim, + hidden_dim=enc_num_hidden, + out_dim=enc_num_hidden, + n_layers=n_layers, + n_heads=enc_nhead, + n_heads_out=enc_nhead, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=True, + ) + + # build decoder for attribute prediction + self.decoder = GAT( + n_dim=dec_in_dim, + e_dim=e_dim, + hidden_dim=dec_num_hidden, + out_dim=n_dim, + n_layers=1, + n_heads=n_heads, + n_heads_out=1, + concat_out=True, + activation=activation, + feat_drop=feat_drop, + attn_drop=0.0, + negative_slope=negative_slope, + residual=residual, + norm=create_norm(norm), + encoding=False, + ) + + self.enc_mask_token = nn.Parameter(torch.zeros(1, n_dim)) + self.encoder_to_decoder = nn.Linear(dec_in_dim * n_layers, dec_in_dim, bias=False) + + # * setup loss function + self.criterion = self.setup_loss_fn(loss_fn, alpha_l) + + @property + def output_hidden_dim(self): + return self._output_hidden_size + + def setup_loss_fn(self, loss_fn, alpha_l): + if loss_fn == "sce": + criterion = partial(sce_loss, alpha=alpha_l) + else: + raise NotImplementedError + return criterion + + def encoding_mask_noise(self, g, mask_rate=0.3): + new_g = g.clone() + num_nodes = g.num_nodes() + perm = torch.randperm(num_nodes, device=g.device) + + # random masking + num_mask_nodes = int(mask_rate * num_nodes) + mask_nodes = perm[: num_mask_nodes] + keep_nodes = perm[num_mask_nodes:] + + new_g.ndata["attr"][mask_nodes] = self.enc_mask_token + + return new_g, (mask_nodes, keep_nodes) + + def forward(self, g): + loss = self.compute_loss(g) + 
return loss + + def compute_loss(self, g): + # Feature Reconstruction + pre_use_g, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, self._mask_rate) + pre_use_x = pre_use_g.ndata['attr'].to(pre_use_g.device) + use_g = pre_use_g + enc_rep, all_hidden = self.encoder(use_g, pre_use_x, return_hidden=True) + enc_rep = torch.cat(all_hidden, dim=1) + rep = self.encoder_to_decoder(enc_rep) + + recon = self.decoder(pre_use_g, rep) + x_init = g.ndata['attr'][mask_nodes] + x_rec = recon[mask_nodes] + loss = self.criterion(x_rec, x_init) + + # Structural Reconstruction + threshold = min(10000, g.num_nodes()) + + negative_edge_pairs = dgl.sampling.global_uniform_negative_sampling(g, threshold) + positive_edge_pairs = random.sample(range(g.number_of_edges()), threshold) + positive_edge_pairs = (g.edges()[0][positive_edge_pairs], g.edges()[1][positive_edge_pairs]) + sample_src = enc_rep[torch.cat([positive_edge_pairs[0], negative_edge_pairs[0]])].to(g.device) + sample_dst = enc_rep[torch.cat([positive_edge_pairs[1], negative_edge_pairs[1]])].to(g.device) + y_pred = self.edge_recon_fc(torch.cat([sample_src, sample_dst], dim=-1)).squeeze(-1) + y = torch.cat([torch.ones(len(positive_edge_pairs[0])), torch.zeros(len(negative_edge_pairs[0]))]).to( + g.device) + loss += self.recon_loss(y_pred, y) + return loss + + def embed(self, g): + x = g.ndata['attr'].to(g.device) + rep = self.encoder(g, x) + return rep + + @property + def enc_params(self): + return self.encoder.parameters() + + @property + def dec_params(self): + return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) diff --git a/model/eval.py b/model/eval.py old mode 100644 new mode 100755 index 7446db31e19a2e144eb061e0526cb44e23798935..ff777aeae5f9a8664fe89c384334eb93d02c5f44 --- a/model/eval.py +++ b/model/eval.py @@ -7,35 +7,28 @@ import numpy as np from sklearn.metrics import roc_auc_score, precision_recall_curve from sklearn.neighbors import NearestNeighbors from utils.utils import set_random_seed -from utils.dataloader import load_graph, load_data -import pickle as pkl +from data.data_loader import load_batch_level_dataset_main +from utils.loaddata import transform_graph, load_batch_level_dataset -def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): - print('Start Evaluation') +def batch_level_evaluation(model, pooler, device, method, dataset, n_dim=0, e_dim=0): model.eval() x_list = [] y_list = [] - data = load_data(dataset) + data = load_batch_level_dataset(dataset) full = data['full_index'] - labels = data['labels'] + graphs = data['dataset'] with torch.no_grad(): for i in full: - #break - g = load_graph(i, dataset, device) - label = labels[i] + g = transform_graph(graphs[i][0], n_dim, e_dim).to(device) + label = graphs[i][1] out = model.embed(g) if dataset != 'wget': - out = pooler(g[-1], out).cpu().numpy() + out = pooler(g, out).cpu().numpy() else: - out = pooler(g[-1], out, [1]).cpu().numpy() + out = pooler(g, out, [2]).cpu().numpy() y_list.append(label) x_list.append(out) - - #pkl.dump(x_list,open('xlist.pkl','wb') ) - #pkl.dump(y_list,open('ylist.pkl','wb') ) - #x_list = pkl.load(open('xlist.pkl','rb')) - #y_list = pkl.load(open('ylist.pkl','rb')) x = np.concatenate(x_list, axis=0) y = np.array(y_list) if 'knn' in method: @@ -49,18 +42,9 @@ def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): x, y = embeddings, labels if dataset == 'streamspot': train_count = 400 - elif (dataset == 'Unicorn-Cadets' or dataset == 'wget-long'): - train_count = 70 - elif 
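The `build_model` / `GMAEModel` pair above is driven from `main.py` through a plain dict of hyper-parameters. A construction sketch with the values `main.py` uses for the batch-level datasets; `negative_slope`, `mask_rate` and `alpha_l` normally come from `build_args()` in `utils/config.py`, so the numbers below are placeholders, as are the feature dimensions (in practice they are the dataset's `n_feat` / `e_feat` metadata):

```python
# Sketch, not part of the patch: construct the graph masked autoencoder directly.
from model.autoencoder import build_model

main_args = {
    "num_hidden": 256,      # value main.py sets for "wget"/"streamspot"
    "num_layers": 4,
    "negative_slope": 0.2,  # placeholder; real value comes from build_args()
    "mask_rate": 0.5,       # placeholder; real value comes from build_args()
    "alpha_l": 3,           # placeholder; real value comes from build_args()
    "n_dim": 30,            # placeholder node-feature dimension (metadata['n_feat'])
    "e_dim": 10,            # placeholder edge-feature dimension (metadata['e_feat'])
}

model = build_model(main_args)
print(sum(p.numel() for p in model.parameters()), "parameters")

# Training and inference then follow the API defined above:
#   loss = model(batched_dgl_graph)    # masked-feature + edge-reconstruction loss
#   reps = model.embed(batched_dgl_graph)
```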
(dataset == 'wget' or dataset == 'SC2'): - train_count = 100 - else: - train_count = 30 - - if (dataset =='SC2'): - n_neighbors = min(int(train_count * 0.02), 10) else: - n_neighbors = 100 - + train_count = 100 + n_neighbors = min(int(train_count * 0.02), 10) benign_idx = np.where(y == 0)[0] attack_idx = np.where(y == 1)[0] if repeat != -1: @@ -134,65 +118,50 @@ def evaluate_batch_level_using_knn(repeat, dataset, embeddings, labels): np.random.shuffle(benign_idx) np.random.shuffle(attack_idx) x_train = x[benign_idx[:train_count]] - #x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) x_test = np.concatenate([x[benign_idx[train_count:]], x[attack_idx]], axis=0) y_test = np.concatenate([y[benign_idx[train_count:]], y[attack_idx]], axis=0) x_train_mean = x_train.mean(axis=0) x_train_std = x_train.std(axis=0) + for i in range(len(x_train_std)): + if (x_train_std[i] == 0 ): + x_train_std[i] = 0.000000000000001 x_train = (x_train - x_train_mean) / x_train_std x_test = (x_test - x_train_mean) / x_train_std - f1_max = 0 - for n_neighbors in range(1, train_count): - nbrs = NearestNeighbors(n_neighbors=n_neighbors) - nbrs.fit(x_train) - distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) - mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) - #mean_distance = 0.1 - distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) - - score = distances.mean(axis=1) / mean_distance - auc = roc_auc_score(y_test, score) - prec, rec, threshold = precision_recall_curve(y_test, score) - f1 = 2 * prec * rec / (rec + prec + 1e-9) - best_idx = np.argmax(f1) - best_thres = threshold[best_idx] - - tn = 0 - fn = 0 - tp = 0 - fp = 0 - - for i in range(len(y_test)): - if y_test[i] == 1.0 and score[i] >= best_thres: - tp += 1 - if y_test[i] == 1.0 and score[i] < best_thres: - fn += 1 - if y_test[i] == 0.0 and score[i] < best_thres: - tn += 1 - if y_test[i] == 0.0 and score[i] >= best_thres: - fp += 1 - - if (f1[best_idx]> f1_max): - f1_max = f1[best_idx] - auc_max = auc - prec_max = prec[best_idx] - rec_max = rec[best_idx] - tn_max = tn - fn_max = fn - tp_max = tp - fp_max = fp - best_n = n_neighbors - - print('AUC: {}'.format(auc_max)) - print('F1: {}'.format(f1_max)) - print('PRECISION: {}'.format(prec_max)) - print('RECALL: {}'.format(rec_max)) - print('TN: {}'.format(tn_max)) - print('FN: {}'.format(fn_max)) - print('TP: {}'.format(tp_max)) - print('FP: {}'.format(fp_max)) - print(best_n) + nbrs = NearestNeighbors(n_neighbors=n_neighbors) + nbrs.fit(x_train) + distances, indexes = nbrs.kneighbors(x_train, n_neighbors=n_neighbors) + mean_distance = distances.mean() * n_neighbors / (n_neighbors - 1) + distances, indexes = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + + score = distances.mean(axis=1) / mean_distance + auc = roc_auc_score(y_test, score) + prec, rec, threshold = precision_recall_curve(y_test, score) + f1 = 2 * prec * rec / (rec + prec + 1e-9) + best_idx = np.argmax(f1) + best_thres = threshold[best_idx] + + tn = 0 + fn = 0 + tp = 0 + fp = 0 + for i in range(len(y_test)): + if y_test[i] == 1.0 and score[i] >= best_thres: + tp += 1 + if y_test[i] == 1.0 and score[i] < best_thres: + fn += 1 + if y_test[i] == 0.0 and score[i] < best_thres: + tn += 1 + if y_test[i] == 0.0 and score[i] >= best_thres: + fp += 1 + print('AUC: {}'.format(auc)) + print('F1: {}'.format(f1[best_idx])) + print('PRECISION: {}'.format(prec[best_idx])) + print('RECALL: {}'.format(rec[best_idx])) + print('TN: {}'.format(tn)) + print('FN: {}'.format(fn)) + 
print('TP: {}'.format(tp)) + print('FP: {}'.format(fp)) return auc, 0.0 @@ -202,23 +171,26 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): x_train = (x_train - x_train_mean) / x_train_std x_test = (x_test - x_train_mean) / x_train_std - if dataset == 'cadets-e3': + if dataset == 'cadets': n_neighbors = 200 else: n_neighbors = 10 - nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1) nbrs.fit(x_train) save_dict_path = './eval_result/distance_save_{}.pkl'.format(dataset) if not os.path.exists(save_dict_path): idx = list(range(x_train.shape[0])) - random.shuffle(idx) + random.shuffle(x_train[idx][:min(50000, x_train.shape[0])]) + print(len(x_train[idx][:min(50000, x_train.shape[0])])) distances, _ = nbrs.kneighbors(x_train[idx][:min(50000, x_train.shape[0])], n_neighbors=n_neighbors) del x_train mean_distance = distances.mean() del distances + print('here') + print(len(x_test)) distances, _ = nbrs.kneighbors(x_test, n_neighbors=n_neighbors) + print("yes") save_dict = [mean_distance, distances.mean(axis=1)] distances = distances.mean(axis=1) with open(save_dict_path, 'wb') as f: @@ -231,8 +203,7 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): auc = roc_auc_score(y_test, score) prec, rec, threshold = precision_recall_curve(y_test, score) f1 = 2 * prec * rec / (rec + prec + 1e-9) - best_idx = np.argmax(f1) - + best_idx = -1 for i in range(len(f1)): # To repeat peak performance if dataset == 'trace' and rec[i] < 0.99979: @@ -241,7 +212,7 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): if dataset == 'theia' and rec[i] < 0.99996: best_idx = i - 1 break - if dataset == 'cadets-e3' and rec[i] < 0.9976: + if dataset == 'cadets' and rec[i] < 0.9976: best_idx = i - 1 break best_thres = threshold[best_idx] @@ -252,11 +223,11 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): fp = 0 for i in range(len(y_test)): if y_test[i] == 1.0 and score[i] >= best_thres: - tp += 1 + tn += 1 if y_test[i] == 1.0 and score[i] < best_thres: fn += 1 if y_test[i] == 0.0 and score[i] < best_thres: - tn += 1 + tp += 1 if y_test[i] == 0.0 and score[i] >= best_thres: fp += 1 print('AUC: {}'.format(auc)) @@ -267,4 +238,4 @@ def evaluate_entity_level_using_knn(dataset, x_train, x_test, y_test): print('FN: {}'.format(fn)) print('TP: {}'.format(tp)) print('FP: {}'.format(fp)) - return auc, 0.0, None, None \ No newline at end of file + return auc, 0.0, None, None diff --git a/model/gat.py b/model/gat.py old mode 100644 new mode 100755 index 0b2d2daeca4468f356274049f7f4454496a8ac06..c64054f0dc985f37d6c1588afc5793eaf36417a6 --- a/model/gat.py +++ b/model/gat.py @@ -34,6 +34,7 @@ class GAT(nn.Module): last_activation = create_activation(activation) if encoding else None last_residual = (encoding and residual) last_norm = norm if encoding else None + if self.n_layers == 1: self.gats.append(GATConv( n_dim, e_dim, out_dim, n_heads_out, feat_drop, attn_drop, negative_slope, diff --git a/model/loss_func.py b/model/loss_func.py old mode 100644 new mode 100755 diff --git a/model/mlp.py b/model/mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..e1ee9581a033acdbec3b4501cdd45c2d72e7efd9 --- /dev/null +++ b/model/mlp.py @@ -0,0 +1,13 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.1): + super(MLP, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def 
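The evaluation logic above reduces to a k-NN distance score over the learned embeddings. A self-contained sketch of that scoring on synthetic data (all numbers are illustrative only), including the zero-std guard this patch adds:

```python
# Sketch: fit NearestNeighbors on standardized benign training embeddings,
# score each test point by its mean neighbor distance relative to the training
# mean, then read AUC and the best-F1 threshold off the scores.
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import roc_auc_score, precision_recall_curve

rng = np.random.default_rng(0)
x_train = rng.normal(0.0, 1.0, size=(100, 8))               # benign embeddings
x_test = np.vstack([rng.normal(0.0, 1.0, size=(80, 8)),     # benign test points
                    rng.normal(3.0, 1.0, size=(20, 8))])    # "attack" points
y_test = np.array([0] * 80 + [1] * 20)                      # attacks are the positive class

mean, std = x_train.mean(axis=0), x_train.std(axis=0)
std[std == 0] = 1e-15                                       # same guard the patch adds
x_train, x_test = (x_train - mean) / std, (x_test - mean) / std

n_neighbors = 10
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(x_train)
train_dist, _ = nbrs.kneighbors(x_train)
mean_distance = train_dist.mean() * n_neighbors / (n_neighbors - 1)  # drop the self-distance
test_dist, _ = nbrs.kneighbors(x_test)
score = test_dist.mean(axis=1) / mean_distance

auc = roc_auc_score(y_test, score)
prec, rec, thr = precision_recall_curve(y_test, score)
f1 = 2 * prec * rec / (prec + rec + 1e-9)
print(f"AUC={auc:.3f}  best F1={f1.max():.3f}")
```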
forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) diff --git a/model/model - Copie.py b/model/model - Copie.py deleted file mode 100644 index 9bd76a77dea36485ce923267908cca624db6ae65..0000000000000000000000000000000000000000 --- a/model/model - Copie.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from utils.poolers import Pooling -from dgl.nn import EdgeGATConv, GlobalAttentionPooling -from torch.nn import GRUCell -import dgl - - - - - - -class GNN_DTDG(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, mlp_layers, number_snapshot): - super(GNN_DTDG, self).__init__() - self.encoder = GNN_RNN(n_dim, e_dim, hidden_dim, out_dim ,n_layers, n_heads, device, number_snapshot) - self.decoder = GNN_RNN(out_dim, e_dim, hidden_dim, n_dim ,n_layers, n_heads, device, number_snapshot) - self.number_snapshot = number_snapshot - self.classifier_layers = nn.ModuleList([ - ]) - - for _ in range(mlp_layers - 1): - self.classifier_layers.extend([ - nn.Linear(out_dim, out_dim).to(device), - nn.ReLU(), - ]) - self.classifier_layers.extend([ - nn.Linear(out_dim, 1).to(device), - nn.Sigmoid() - ]) - - self.pooling_gate_nn = nn.Linear(out_dim , 1) - self.pooling = GlobalAttentionPooling(self.pooling_gate_nn) - self.pooler = Pooling("mean") - self.encoder_to_decoder = nn.Linear( out_dim, out_dim, bias=False) - - def forward(self, g): - encoded = self.encoder(g) - new_g = [] - i= 0 - for G in g: - g_encoded = G.clone() - g_encoded.ndata["attr"] = self.encoder_to_decoder(encoded[i]) - new_g.append(g_encoded) - i+=1 - - - - decoded = self.decoder(new_g) - return decoded[-1] - # x = self.pooler(G, embeddings, [1])[0] - # h_g = x.clone() - # for layer in self.classifier_layers: - # x = layer(x) - - def embed(self, g): - return self.encoder(g)[-1] - - - -class GNN_RNN(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot): - super(GNN_RNN, self).__init__() - self.device = device - self.gnn_layers = nn.ModuleList([EdgeGATConv(in_feats=n_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device)]) - - self.out_dim = out_dim - - for _ in range(n_layers-1): - self.gnn_layers.append( - EdgeGATConv(in_feats=out_dim, edge_feats=e_dim, out_feats=out_dim, num_heads=n_heads, allow_zero_in_degree=True).to(device) - ) - - self.rnn_layers = nn.ModuleList([]) - - for _ in range(number_snapshot): - self.rnn_layers.append( - GRUCell(out_dim, out_dim, device = device) - ) - - self.classifier_layers = nn.ModuleList([ - ]) - - - def forward(self, g): - i = 0 - H_s = [] - for G in g: - - with G.to(self.device).local_scope(): - x = G.ndata["attr"].float() - e = G.edata["attr"].float() - for layer in self.gnn_layers: - r = layer(G, x, e) - x = torch.mean(r,dim=1).to(self.device) - del r - - #if ( i == 0): - # H = self.rnn_layers[i](x, x) - # else: - # H = self.rnn_layers[i](x, H) - - H = x - H_s.append(H) - embeddings = H.clone() - i+=1 - #x = self.pooling(g[0], x)[0] - - - return H_s \ No newline at end of file diff --git a/model/model.py b/model/model.py deleted file mode 100644 index 0389eeae588c10742143ba62a97e9b8f5639eb47..0000000000000000000000000000000000000000 --- a/model/model.py +++ /dev/null @@ -1,157 +0,0 @@ -import torch -import torch.nn as nn -from utils.poolers import Pooling -from .loss_func import sce_loss -from .gat import GAT -from .rnn import RNN_Cells -from utils.utils import create_norm -from functools 
import partial - - -class STGNN_AutoEncoder(nn.Module): - def __init__(self, n_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, device, number_snapshot, activation, feat_drop, negative_slope, residual, norm, pooling, loss_fn="sce", alpha_l=2, use_all_hidden = True): - super(STGNN_AutoEncoder, self).__init__() - - #Initialize the encoder and decoder structure - self.encoder = STGNN(n_dim, e_dim, out_dim, out_dim, n_layers, n_heads, n_heads, number_snapshot, activation, feat_drop, negative_slope, residual, norm, True, use_all_hidden, device) - self.decoder = STGNN(out_dim, e_dim, out_dim, n_dim, 1, n_heads, 1, number_snapshot, activation, feat_drop, negative_slope, residual, norm, False, False, device) - - - # Linear layer for mapping encoder output to decoder input - if (use_all_hidden): - self.encoder_to_decoder = nn.Linear(n_layers * out_dim, out_dim, bias=False) - else: - self.encoder_to_decoder = nn.Linear(out_dim, out_dim, bias=False) - - # Additional components and parameters - self.n_layers = n_layers - self.pooler = Pooling(pooling) - self.number_snapshot = number_snapshot - self.use_all_hidden = use_all_hidden - self.device = device - self.criterion = self.setup_loss_fn(loss_fn, alpha_l) - - def setup_loss_fn(self, loss_fn, alpha_l): - if loss_fn == "sce": - criterion = partial(sce_loss, alpha=alpha_l) - - elif loss_fn == "ce": - criterion = nn.CrossEntropyLoss() - elif loss_fn == "mse": - criterion = nn.MSELoss() - elif loss_fn == "mae": - criterion = nn.L1Loss() - else: - raise NotImplementedError - return criterion - - - def forward(self, g): - - # Encode input graphs - node_features = [] - new_t = [] - for G in g: - new_g = G.clone() - node_features.append(new_g.ndata['attr'].float()) - new_g.edata['attr'] = new_g.edata['attr'].float() - new_t.append(new_g) - final_embedding = self.encoder(new_t, node_features) - encoding = [] - if (self.use_all_hidden): - for i in range(len(g)): - conca = [final_embedding[j][i] for j in range(len(final_embedding))] - encoding.append(torch.cat(conca,dim=1)) - else: - encoding = final_embedding[0] - - node_features = [] - for encoded in encoding: - encoded = self.encoder_to_decoder(encoded) - node_features.append(encoded) - - reconstructed = self.decoder(new_t, node_features) - recon = reconstructed[0][-1] - x_init = g[0].ndata['attr'].float() - loss = self.criterion(recon, x_init) - - return loss - - def embed(self, g): - node_features= [] - for G in g: - node_features.append(G.ndata['attr'].float()) - - return self.encoder.embed(g, node_features) - - -class STGNN(nn.Module): - - def __init__(self, input_dim, e_dim, hidden_dim, out_dim, n_layers, n_heads, n_heads_out, n_snapshot, activation, feat_drop, negative_slope, residual, norm, encoding, use_all_hidden, device): - super(STGNN, self).__init__() - - if encoding: - out = out_dim // n_heads - hidden = out_dim // n_heads - else: - hidden = hidden_dim - out = out_dim - - self.gnn = GAT( - n_dim=input_dim, - e_dim=e_dim, - hidden_dim=hidden, - out_dim=out, - n_layers=n_layers, - n_heads=n_heads, - n_heads_out=n_heads_out, - concat_out=True, - activation=activation, - feat_drop=feat_drop, - attn_drop=0.0, - negative_slope=negative_slope, - residual=residual, - norm=create_norm(norm), - encoding=encoding, - ) - self.rnn = RNN_Cells(out_dim, out_dim, n_snapshot, device) - self.use_all_hidden = use_all_hidden - - def forward(self, G, node_features): - - embeddings = [] - for i in range(len(G)): - g = G[i] - if (self.use_all_hidden): - node_embedding, all_hidden = self.gnn(g, node_features[i], 
return_hidden = self.use_all_hidden) - embeddings.append(all_hidden) - n_iter = len(all_hidden) - - else: - embeddings.append(self.gnn(g, node_features[i], return_hidden = self.use_all_hidden)) - n_iter = 1 - - result = [] - for j in range(n_iter): - encoding = [] - - for embedding in embeddings : - if (self.use_all_hidden): - encoding.append(embedding[j]) - else: - encoding.append(embedding) - - result.append(self.rnn(encoding)) - - return result - - - def embed(self, G, node_features): - embeddings = [] - for i in range(len(G)): - g = G[i].clone() - g.edata['attr'] = g.edata['attr'].float() - embedding = self.gnn(g, node_features[i], return_hidden = False) - embeddings.append(embedding) - - return self.rnn(embeddings)[-1] \ No newline at end of file diff --git a/model/rnn.py b/model/rnn.py deleted file mode 100644 index 663634fbdfbba966894f74241faed35876657f5c..0000000000000000000000000000000000000000 --- a/model/rnn.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch.nn import GRUCell -import torch.nn as nn - - -class RNN_Cells(nn.Module): - def __init__(self, input_dim, hidden_dim, n_cells, device) : - super(RNN_Cells, self).__init__() - self.cells = nn.ModuleList() - - for i in range(n_cells): - self.cells.append(GRUCell(input_dim, hidden_dim, device=device)) - - - def forward(self, inputs): - - results = [] - for i in range(len(self.cells)): - if (i == 0): - results.append(self.cells[i](inputs[i], inputs[i])) - else: - results.append(self.cells[i](inputs[i], results[i-1])) - - return results \ No newline at end of file diff --git a/model/train.py b/model/train.py new file mode 100755 index 0000000000000000000000000000000000000000..9ec20ff2f7ec83a695904dff381cb10191d72bfc --- /dev/null +++ b/model/train.py @@ -0,0 +1,28 @@ +import dgl +import numpy as np +from tqdm import tqdm +from utils.loaddata import transform_graph + + +def batch_level_train(model, graphs, train_loader, optimizer, max_epoch, device, n_dim=0, e_dim=0): + epoch_iter = tqdm(range(max_epoch)) + cpt = 1 + for epoch in epoch_iter: + model.train() + loss_list = [] + cpt = 0 + for _, batch in enumerate(train_loader): + cpt +=1 + print(cpt) + batch_g = [transform_graph(graphs[idx][0], n_dim, e_dim).to(device) for idx in batch] + batch_g = dgl.batch(batch_g) + model.train() + loss = model(batch_g) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_list.append(loss.item()) + del batch_g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") + return model diff --git a/model/train_entity.py b/model/train_entity.py deleted file mode 100644 index 58aa15140e6204fbc48d865036ea37eae251858a..0000000000000000000000000000000000000000 --- a/model/train_entity.py +++ /dev/null @@ -1,29 +0,0 @@ -from tqdm import tqdm -from utils.dataloader import load_entity_level_dataset -import torch - - - - -def entity_level_train(model, snapshot, optimizer, max_epoch, device, dataset_name, train_data): - - model = model.to(device) - model.train() - epoch_iter = tqdm(range(max_epoch)) - for epoch in epoch_iter: - epoch_loss = 0.0 - for i in train_data: - g = load_entity_level_dataset(dataset_name, 'train', i, snapshot, device) - model.train() - loss = model(g) - loss /= len(train_data) - optimizer.zero_grad() - epoch_loss += loss.item() - loss.backward() - optimizer.step() - - del g - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - - return model \ No newline at end of file diff --git 
a/model/train_graph.py b/model/train_graph.py deleted file mode 100644 index 1ffbd9869fab466b92497c4815db704f9adc1727..0000000000000000000000000000000000000000 --- a/model/train_graph.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -from tqdm import tqdm -import torch -from utils.dataloader import load_graph -import matplotlib.pyplot as plt -from model.eval import batch_level_evaluation - -def batch_level_train(model, train_loader, optimizer, max_epoch, device, n_dim, e_dim, dataset_name, validation=True): - - epoch_iter = tqdm(range(max_epoch)) - model.to(device) # Move model to GPU - n_epoch = 0 - validation_f1 = [] - loss_global = [] - for epoch in epoch_iter: - model.train() - loss_list = [] - for iter, batch in enumerate(train_loader): - batch_g = [load_graph(int(idx), dataset_name, device) for idx in batch] # Move data to GPU - model.train() - g = batch_g[0] - loss = model(g) - optimizer.zero_grad() - loss.backward() - optimizer.step() - loss_list.append(loss.item()) - del batch_g, g - - n_epoch +=1 - - if (validation): - validation_f1.append(batch_level_evaluation(model, model.pooler, device, ['knn'], dataset_name, n_dim, e_dim)[0]) - - loss_global.append(np.mean(loss_list)) - torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) - epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}") - - if (validation): - plt.plot(list(range(n_epoch)), validation_f1, label='Graph 2', marker='o', linestyle='-') - plt.plot(list(range(n_epoch)), loss_global, label='Graph 2', marker='x', linestyle='--') - plt.xlabel('X-axis') - plt.ylabel('Y-axis') - plt.title('Two Graphs on the Same Plot') - - # Add a legend - plt.legend() - - # Display the plot - plt.show() - return model diff --git a/mpi_host_file b/mpi_host_file old mode 100644 new mode 100755 diff --git a/requirement.txt b/requirement.txt old mode 100644 new mode 100755 diff --git a/result/FedAvg-SC2.pt b/result/FedAvg-SC2.pt deleted file mode 100644 index 89411388841175b247e424a74463231d45fe1391..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-SC2.pt and /dev/null differ diff --git a/result/FedAvg-Unicorn-Cadets.pt b/result/FedAvg-Unicorn-Cadets.pt deleted file mode 100644 index 161ba7a56e45a3d1c33bed20b12bf8564b50a8c8..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-Unicorn-Cadets.pt and /dev/null differ diff --git a/result/FedAvg-cadets-e3.pt b/result/FedAvg-cadets-e3.pt deleted file mode 100644 index 4afaa90ee421747fa19d39673f9b63b06eb9804f..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-cadets-e3.pt and /dev/null differ diff --git a/result/FedAvg-clearscope-e3.pt b/result/FedAvg-clearscope-e3.pt deleted file mode 100644 index 136dc3c1eef647a32d2189a8896551aa1b5df162..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-clearscope-e3.pt and /dev/null differ diff --git a/result/FedAvg-streamspot.pt b/result/FedAvg-streamspot.pt deleted file mode 100644 index 82308fe0b86721ef8790eff5354ec0e74388d70e..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-streamspot.pt and /dev/null differ diff --git a/result/FedAvg-theia-e3.pt b/result/FedAvg-theia-e3.pt deleted file mode 100644 index 5a995ad213ddd3883f7c8f000095fd6821410d2b..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-theia-e3.pt and /dev/null differ diff --git a/result/FedAvg-theia.pt b/result/FedAvg-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedAvg-trace-e3.pt 
b/result/FedAvg-trace-e3.pt deleted file mode 100644 index 4dca8e988169b150f8d1692ad751a5bb051e5609..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-trace-e3.pt and /dev/null differ diff --git a/result/FedAvg-wget-long.pt b/result/FedAvg-wget-long.pt deleted file mode 100644 index 9e7504eb16446ab46f327bcc30f13d77b051f75e..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-wget-long.pt and /dev/null differ diff --git a/result/FedAvg-wget.pt b/result/FedAvg-wget.pt deleted file mode 100644 index 51c0ecdfba6589f6d63433368fb10c5810705765..0000000000000000000000000000000000000000 Binary files a/result/FedAvg-wget.pt and /dev/null differ diff --git a/result/FedAvg_Streamspot-streamspot.pt b/result/FedAvg_Streamspot-streamspot.pt old mode 100644 new mode 100755 diff --git a/result/FedOpt-theia.pt b/result/FedOpt-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedOpt_Streamspot.pt b/result/FedOpt_Streamspot.pt old mode 100644 new mode 100755 diff --git a/result/FedProx-theia.pt b/result/FedProx-theia.pt old mode 100644 new mode 100755 diff --git a/result/FedProx_Streamspot.pt b/result/FedProx_Streamspot.pt old mode 100644 new mode 100755 diff --git a/save_results/distance_save_cadets_FedAvg.pkl b/save_results/distance_save_cadets_FedAvg.pkl new file mode 100644 index 0000000000000000000000000000000000000000..870244cd32749e587a2b411dc1c13dee18d08781 Binary files /dev/null and b/save_results/distance_save_cadets_FedAvg.pkl differ diff --git a/save_results/distance_save_cadets_FedOpt.pkl b/save_results/distance_save_cadets_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9bdff9888dbf1ee8f9fb45f255b54502b4b56d6c Binary files /dev/null and b/save_results/distance_save_cadets_FedOpt.pkl differ diff --git a/save_results/distance_save_cadets_FedProx.pkl b/save_results/distance_save_cadets_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8680ed4c84080861d938bc5ea620d9d8589fc39a Binary files /dev/null and b/save_results/distance_save_cadets_FedProx.pkl differ diff --git a/save_results/distance_save_theia_FedAvg.pkl b/save_results/distance_save_theia_FedAvg.pkl new file mode 100755 index 0000000000000000000000000000000000000000..d885a972d533697733d1c953742df203a5c183a9 Binary files /dev/null and b/save_results/distance_save_theia_FedAvg.pkl differ diff --git a/save_results/distance_save_theia_FedOpt.pkl b/save_results/distance_save_theia_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3e84da9be1a793612b77573f16b80f4bba3f530b Binary files /dev/null and b/save_results/distance_save_theia_FedOpt.pkl differ diff --git a/save_results/distance_save_theia_FedProx.pkl b/save_results/distance_save_theia_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..80e772d9f66424b552809cec77dae53a291f83e4 Binary files /dev/null and b/save_results/distance_save_theia_FedProx.pkl differ diff --git a/save_results/distance_save_trace_FedAvg.pkl b/save_results/distance_save_trace_FedAvg.pkl new file mode 100644 index 0000000000000000000000000000000000000000..12baf82b2236f724c1c9f65a3adb19838a693c4b Binary files /dev/null and b/save_results/distance_save_trace_FedAvg.pkl differ diff --git a/save_results/distance_save_trace_FedOpt.pkl b/save_results/distance_save_trace_FedOpt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6d90af670837a994ba274a3a3a2d09ea5b941d31 Binary files /dev/null and b/save_results/distance_save_trace_FedOpt.pkl 
differ diff --git a/save_results/distance_save_trace_FedProx.pkl b/save_results/distance_save_trace_FedProx.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e88de86f4bcb805e0475a41d4de3bbfde5d7353e Binary files /dev/null and b/save_results/distance_save_trace_FedProx.pkl differ diff --git a/test.py b/test.py new file mode 100755 index 0000000000000000000000000000000000000000..c1f6ebe728019b94bfcd9513907c28a9d39eaa81 --- /dev/null +++ b/test.py @@ -0,0 +1,15 @@ +from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split + + +( + train_data_num, + val_data_num, + test_data_num, + train_data_global, + val_data_global, + test_data_global, + data_local_num_dict, + train_data_local_dict, + val_data_local_dict, + test_data_local_dict, + ) = load_partition_data(None, 3, "streamspot") diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..968181a1cfaa6de944115c62b2b187376c734502 --- /dev/null +++ b/train.py @@ -0,0 +1,90 @@ +import os +import random +import torch +import warnings +from tqdm import tqdm +from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph +from model.autoencoder import build_model +from torch.utils.data.sampler import SubsetRandomSampler +from dgl.dataloading import GraphDataLoader +import dgl +from model.train import batch_level_train +from utils.utils import set_random_seed, create_optimizer +from utils.config import build_args +warnings.filterwarnings('ignore') + + +def extract_dataloaders(entries, batch_size): + random.shuffle(entries) + train_idx = torch.arange(len(entries)) + train_sampler = SubsetRandomSampler(train_idx) + train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler) + return train_loader + + +def main(main_args): + device = "cpu" + dataset_name = "trace" + if dataset_name == 'streamspot': + main_args.num_hidden = 256 + main_args.max_epoch = 5 + main_args.num_layers = 4 + elif dataset_name == 'wget': + main_args.num_hidden = 256 + main_args.max_epoch = 2 + main_args.num_layers = 4 + else: + main_args["num_hidden"] = 64 + main_args["max_epoch"] = 50 + main_args["num_layers"] = 3 + set_random_seed(0) + + if dataset_name == 'streamspot' or dataset_name == 'wget': + if dataset_name == 'streamspot': + batch_size = 12 + else: + batch_size = 1 + dataset = load_batch_level_dataset(dataset_name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + graphs = dataset['dataset'] + train_index = dataset['train_index'] + main_args.n_dim = n_node_feat + main_args.e_dim = n_edge_feat + model = build_model(main_args) + model = model.to(device) + optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay) + model = batch_level_train(model, graphs, (extract_dataloaders(train_index, batch_size)), + optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim) + torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name)) + else: + metadata = load_metadata(dataset_name) + main_args["n_dim"] = metadata['node_feature_dim'] + main_args["e_dim"] = metadata['edge_feature_dim'] + model = build_model(main_args) + model = model.to(device) + model.train() + optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"]) + epoch_iter = tqdm(range(main_args["max_epoch"])) + n_train = metadata['n_train'] + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in range(n_train): + g = 
load_entity_level_dataset(dataset_name, 'train', i).to(device) + model.train() + loss = model(g) + loss /= n_train + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward() + optimizer.step() + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name)) + + return + + +if __name__ == '__main__': + args = build_args() + main(args) diff --git a/trainer/__pycache__/magic_aggregator.cpython-310.pyc b/trainer/__pycache__/magic_aggregator.cpython-310.pyc deleted file mode 100644 index 173a82a3b45a20cb705440e95a4e6d09bf98930b..0000000000000000000000000000000000000000 Binary files a/trainer/__pycache__/magic_aggregator.cpython-310.pyc and /dev/null differ diff --git a/trainer/__pycache__/magic_aggregator.cpython-311.pyc b/trainer/__pycache__/magic_aggregator.cpython-311.pyc index e4c63a27702b466871bad2b0f18230d3620cd143..d9f8b782ee6ef9465d777c9e882fb180828f07ea 100644 Binary files a/trainer/__pycache__/magic_aggregator.cpython-311.pyc and b/trainer/__pycache__/magic_aggregator.cpython-311.pyc differ diff --git a/trainer/__pycache__/magic_trainer.cpython-310.pyc b/trainer/__pycache__/magic_trainer.cpython-310.pyc deleted file mode 100644 index 5aa1877df83f3deb6862f39988826b001e5273f8..0000000000000000000000000000000000000000 Binary files a/trainer/__pycache__/magic_trainer.cpython-310.pyc and /dev/null differ diff --git a/trainer/__pycache__/magic_trainer.cpython-311.pyc b/trainer/__pycache__/magic_trainer.cpython-311.pyc index a42939be3e8c7c9ffc0d3985c7b46efe4ab53d0f..a9ecabf7f8cc81c7ccd9f920ba01e5e3c9436ee9 100644 Binary files a/trainer/__pycache__/magic_trainer.cpython-311.pyc and b/trainer/__pycache__/magic_trainer.cpython-311.pyc differ diff --git a/trainer/__pycache__/single_trainer.cpython-311.pyc b/trainer/__pycache__/single_trainer.cpython-311.pyc index 4d343331aa188a304c811aafdd0d1814cc17efe9..5bfbd974b57fc9e9abe1ee459c27d16e5465c94d 100644 Binary files a/trainer/__pycache__/single_trainer.cpython-311.pyc and b/trainer/__pycache__/single_trainer.cpython-311.pyc differ diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py old mode 100644 new mode 100755 index a5d09149bf53e9180f58802853aa4e6eca33c438..a9d7f70c78d64347b655494de696050e82d0fa6f --- a/trainer/magic_aggregator.py +++ b/trainer/magic_aggregator.py @@ -6,15 +6,11 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from utils.config import build_args from fedml.core import ServerAggregator -from model.train_graph import batch_level_train -from model.train_entity import entity_level_train from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn - - -from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling -from utils.config import build_args -from utils.dataloader import load_data, load_entity_level_dataset, load_metadata +# Trainer for MoleculeNet. The evaluation metric is ROC-AUC +from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset +from utils.loaddata import load_batch_level_dataset class MagicWgetAggregator(ServerAggregator): def __init__(self, model, args, name): @@ -66,30 +62,25 @@ class MagicWgetAggregator(ServerAggregator): logging.info("Models match perfectly! 
:)") def _test(self, test_data, device, args): - main_args = build_args() - dataset_name = self.name - nsnapshot = args.snapshot - if (self.name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - logging.info("----------test--------") - set_random_seed(0) - dataset = load_data(dataset_name) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat + args = build_args() + if (self.name == 'wget' or self.name == 'streamspot'): + logging.info("----------test--------") + model = self.model model.eval() model.to(device) - model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - model = model.to(device) - pooler = Pooling(main_args.pooling) - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, - main_args.e_dim) + pooler = Pooling(args["pooling"]) + dataset = load_batch_level_dataset(self.name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) else: + torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name)) metadata = load_metadata(self.name) - main_args.n_dim = metadata['node_feature_dim'] - main_args.e_dim = metadata['edge_feature_dim'] + args["n_dim"] = metadata['node_feature_dim'] + args["e_dim"] = metadata['edge_feature_dim'] model = self.model.to(device) model.eval() malicious, _ = metadata['malicious'] @@ -99,17 +90,17 @@ class MagicWgetAggregator(ServerAggregator): with torch.no_grad(): x_train = [] for i in range(n_train): - g = load_entity_level_dataset(dataset_name, 'train', i, nsnapshot, device) + g = load_entity_level_dataset(self.name, 'train', i).to(device) x_train.append(model.embed(g).cpu().detach().numpy()) del g x_train = np.concatenate(x_train, axis=0) skip_benign = 0 x_test = [] for i in range(n_test): - g = load_entity_level_dataset(self.name, 'test', i, nsnapshot, device) + g = load_entity_level_dataset(self.name, 'test', i).to(device) # Exclude training samples from the test set if i != n_test - 1: - skip_benign += g[0].number_of_nodes() + skip_benign += g.number_of_nodes() x_test.append(model.embed(g).cpu().detach().numpy()) del g x_test = np.concatenate(x_test, axis=0) diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py old mode 100644 new mode 100755 index 02c428c2459ec140dee5be82bbab64fcf0532c33..99c9a5b49e31a631f30053db0c3cf0623569b2eb --- a/trainer/magic_trainer.py +++ b/trainer/magic_trainer.py @@ -10,18 +10,18 @@ import wandb from sklearn.metrics import roc_auc_score, precision_recall_curve, auc from fedml.core import ClientTrainer +from model.autoencoder import build_model from torch.utils.data.sampler import SubsetRandomSampler from dgl.dataloading import GraphDataLoader - -from model.train_graph import batch_level_train -from model.train_entity import entity_level_train -from model.eval import 
batch_level_evaluation, evaluate_entity_level_using_knn - - +from model.train import batch_level_train from utils.utils import set_random_seed, create_optimizer from utils.poolers import Pooling from utils.config import build_args -from utils.dataloader import load_data, load_entity_level_dataset, load_metadata +from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata +from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn +from model.autoencoder import build_model +from data.data_loader import transform_data +from utils.loaddata import load_batch_level_dataset # Trainer for MoleculeNet. The evaluation metric is ROC-AUC def extract_dataloaders(entries, batch_size): @@ -47,47 +47,77 @@ class MagicTrainer(ClientTrainer): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): - main_args = build_args() - dataset_name = self.name - input('start') - if (dataset_name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - if (dataset_name == 'wget' or dataset_name == 'streamspot' or dataset_name == 'Unicorn-Cadets' or dataset_name == 'clearscope-e3'): - batch_size = 1 - main_args.max_epoch = 6 - - elif (dataset_name == 'SC2' or dataset_name == 'wget-long'): - batch_size = 1 - main_args.max_epoch = 1 - + test_data = None + args = build_args() + if (self.name == "wget"): + args["num_hidden"] = 256 + args["max_epoch"] = 2 + args["num_layers"] = 4 + batch_size = 1 + elif (self.name == "streamspot"): + args["num_hidden"] = 256 + args["max_epoch"] = 5 + args["num_layers"] = 4 + batch_size = 12 + else: + args["num_hidden"] = 64 + args["max_epoch"] = 50 + args["num_layers"] = 3 - dataset = load_data(dataset_name) + max_test_score = 0 + best_model_params = {} + if (self.name == 'wget' or self.name == 'streamspot'): + + dataset = load_batch_level_dataset(self.name) + data = transform_data(train_data) n_node_feat = dataset['n_feat'] n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - validation_index = dataset['validation_index'] - label = dataset["labels"] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat - main_args.optimizer = "adamw" - set_random_seed(0) - #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) - train_loader = extract_dataloaders(train_data[0], batch_size) - self.model = batch_level_train(self.model, train_loader, optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim, dataset_name, validation= False) + graphs = data['dataset'] + train_index = data['train_index'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + #self.model = build_model(args) + self.model = self.model.to(device) + optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) + self.model = batch_level_train(self.model, graphs, (extract_dataloaders(train_index, batch_size)), + optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat) + test_score, _ = self.test(test_data, device, 
args) else: - main_args.max_epoch = 50 - nsnapshot = args.snapshot - main_args.optimizer = "adam" - set_random_seed(0) - #model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - optimizer = create_optimizer(main_args.optimizer, self.model, main_args.lr, main_args.weight_decay) - #train_loader = extract_dataloaders(train_index, batch_size) - self.model = entity_level_train(self.model, nsnapshot, optimizer, main_args.max_epoch, device, dataset_name, train_data[0]) - - self.max = 0 - best_model_params = { + + metadata = load_metadata(self.name) + args["n_dim"] = metadata['node_feature_dim'] + args["e_dim"] = metadata['edge_feature_dim'] + self.model = self.model.to(device) + self.model.train() + optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"]) + epoch_iter = tqdm(range(args["max_epoch"])) + n_train = len(train_data[0]) + input("start?") + for epoch in epoch_iter: + epoch_loss = 0.0 + for i in range(n_train): + g = train_data[0][i] + self.model.train() + loss = self.model(g) + loss /= n_train + optimizer.zero_grad() + epoch_loss += loss.item() + loss.backward(retain_graph=True) + optimizer.step() + del g + epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}") + if (self.name == 'wget' or self.name == 'streamspot'): + test_score, _ = self.test(test_data, device, args) + if test_score > self.max: + self.max = test_score + best_model_params = { + k: v.cpu() for k, v in self.model.state_dict().items() + } + else: + self.max = 0 + best_model_params = { k: v.cpu() for k, v in self.model.state_dict().items() - } + } @@ -96,31 +126,17 @@ class MagicTrainer(ClientTrainer): def test(self, test_data, device, args): if (self.name == 'wget' or self.name == 'streamspot'): logging.info("----------test--------") - main_args = build_args() - dataset_name = self.name - if dataset_name in ['streamspot', 'wget', 'SC2']: - main_args.num_hidden = 256 - main_args.num_layers = 4 - main_args.max_epoch = 50 - else: - main_args.num_hidden = 64 - main_args.num_layers = 3 - set_random_seed(0) - dataset = load_data(dataset_name, 1, 0.6, 0.2) - n_node_feat = dataset['n_feat'] - n_edge_feat = dataset['e_feat'] - #train_index = [104, 118, 86, 74, 16, 12, 117, 108, 59, 146, 97, 49, 107, 47, 23, 111, 32, 124, 121, 119, 141, 50, 43, 98, 73, 80, 4, 140, 1, 17, 55, 136, 95, 120, 103, 94, 34, 68, 130, 26, 30, 29, 129, 71, 6, 128, 84, 85, 72, 96, 87, 58, 81, 79, 31, 37, 54, 93, 135, 33, 61, 134, 52, 106, 126, 139, 8, 115, 82, 46, 101, 114, 60, 138, 132, 5, 2, 19, 143, 77, 92, 123, 42, 113, 125, 15, 105, 14, 145, 148] - main_args.n_dim = n_node_feat - main_args.e_dim = n_edge_feat + args = build_args() model = self.model model.eval() model.to(device) - model.load_state_dict(torch.load("./checkpoints/checkpoint-{}.pt".format(dataset_name), map_location=device)) - model = model.to(device) - pooler = Pooling(main_args.pooling) - test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], dataset_name, main_args.n_dim, - main_args.e_dim) - + pooler = Pooling(args["pooling"]) + dataset = load_batch_level_dataset(self.name) + n_node_feat = dataset['n_feat'] + n_edge_feat = dataset['e_feat'] + args["n_dim"] = n_node_feat + args["e_dim"] = n_edge_feat + test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"]) else: metadata = load_metadata(self.name) args["n_dim"] = metadata['node_feature_dim'] diff --git a/trainer/single_trainer.py 
b/trainer/single_trainer.py old mode 100644 new mode 100755 diff --git a/trainer/utils/__pycache__/config.cpython-311.pyc b/trainer/utils/__pycache__/config.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..233d940b3ec7d7070e2ee889729ac7592675d3bf Binary files /dev/null and b/trainer/utils/__pycache__/config.cpython-311.pyc differ diff --git a/trainer/utils/config.py b/trainer/utils/config.py new file mode 100755 index 0000000000000000000000000000000000000000..bb9c1ef28d385b80c7673027b4f69b31c36b1175 --- /dev/null +++ b/trainer/utils/config.py @@ -0,0 +1,20 @@ +import argparse + + +def build_args(): + parser = argparse.ArgumentParser(description="MAGIC") + parser.add_argument("--dataset", type=str, default="wget") + parser.add_argument("--device", type=int, default=-1) + parser.add_argument("--lr", type=float, default=0.001, + help="learning rate") + parser.add_argument("--weight_decay", type=float, default=5e-4, + help="weight decay") + parser.add_argument("--negative_slope", type=float, default=0.2, + help="the negative slope of leaky relu for GAT") + parser.add_argument("--mask_rate", type=float, default=0.5) + parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") + parser.add_argument("--optimizer", type=str, default="adam") + parser.add_argument("--loss_fn", type=str, default='sce') + parser.add_argument("--pooling", type=str, default="mean") + args = parser.parse_args() + return args diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py new file mode 100755 index 0000000000000000000000000000000000000000..41e7dfc03ee39adcc1f5729d59aa21124d981fff --- /dev/null +++ b/trainer/utils/loaddata.py @@ -0,0 +1,197 @@ +import pickle as pkl +import time +import torch.nn.functional as F +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os + + +class StreamspotDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(StreamspotDataset, self).__init__(name=name) + if name == 'streamspot': + path = './data/streamspot' + num_graphs = 600 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 300 <= idx <= 399: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + path = './data/wget/final' + num_graphs = 150 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 0 <= idx <= 24: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +def load_rawdata(name): + if name == 'streamspot': + path = './data/streamspot' + if os.path.exists(path + '/graphs.pkl'): + 
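Side note on the dataset loaders introduced above (StreamspotDataset, WgetDataset, load_rawdata): they all follow the same pattern of turning node-link JSON into DGL graphs with integer `type` attributes on nodes and edges, and caching the parsed result as a pickle so `dgl.from_networkx` does not have to re-run over every JSON file on each start-up. A minimal standalone sketch of that pattern follows; the file path and cache name are illustrative placeholders, not guaranteed paths.

```
import json
import os
import pickle as pkl

import dgl
import networkx as nx


def load_one_graph(json_path):
    # node-link JSON (as written by nx.node_link_data) -> DGLGraph,
    # keeping the integer 'type' attribute on nodes and edges.
    with open(json_path, 'r', encoding='utf-8') as f:
        nx_g = nx.node_link_graph(json.load(f))
    return dgl.from_networkx(nx_g, node_attrs=['type'], edge_attrs=['type'])


def load_cached(json_path, cache_path='./graphs.pkl'):
    # Same caching idea as load_rawdata: parse once, reuse the pickle afterwards.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pkl.load(f)
    g = load_one_graph(json_path)
    with open(cache_path, 'wb') as f:
        pkl.dump(g, f)
    return g


if __name__ == '__main__':
    # Illustrative path following the wget layout above; adjust to your data.
    print(load_cached('./data/wget/final/0.json'))
```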
print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = StreamspotDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + elif name == 'wget': + path = './data/wget' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = WgetDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + else: + raise NotImplementedError + return raw_data + + +def load_batch_level_dataset(dataset_name): + dataset = load_rawdata(dataset_name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +def transform_graph(g, node_feature_dim, edge_feature_dim): + new_g = g.clone() + new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() + new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() + return new_g + + +def preload_entity_level_dataset(path): + path = './data/' + path + if os.path.exists(path + '/metadata.json'): + pass + else: + print('transforming') + train_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] + print('transforming') + test_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + node_feature_dim = 0 + for g in train_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + for g in test_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + node_feature_dim += 1 + edge_feature_dim = 0 + for g in train_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + for g in test_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + edge_feature_dim += 1 + result_test_gs = [] + for g in test_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_test_gs.append(g) + result_train_gs = [] + for g in train_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_train_gs.append(g) + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': len(result_train_gs), + 'n_test': len(result_test_gs) + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + for i, g in enumerate(result_train_gs): + with open(path + '/train{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + for i, g in enumerate(result_test_gs): + with open(path + '/test{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + + +def load_metadata(path): + preload_entity_level_dataset(path) + with open('./data/' + path + 
'/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(path, t, n): + preload_entity_level_dataset(path) + with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: + data = pkl.load(f) + return data diff --git a/trainer/utils/poolers.py b/trainer/utils/poolers.py new file mode 100755 index 0000000000000000000000000000000000000000..4b9dc4aee1eac42f578952596a5e400adbd1e391 --- /dev/null +++ b/trainer/utils/poolers.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +class Pooling(nn.Module): + def __init__(self, pooler): + super(Pooling, self).__init__() + self.pooler = pooler + + def forward(self, graph, feat, t=None): + feat = feat + # Implement node type-specific pooling + with graph.local_scope(): + if t is None: + if self.pooler == 'mean': + return feat.mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat.sum(0, keepdim=True) + elif self.pooler == 'max': + return feat.max(0, keepdim=True) + else: + raise NotImplementedError + elif isinstance(t, int): + mask = (graph.ndata['type'] == t) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError + else: + mask = (graph.ndata['type'] == t[0]) + for i in range(1, len(t)): + mask |= (graph.ndata['type'] == t[i]) + if self.pooler == 'mean': + return feat[mask].mean(0, keepdim=True) + elif self.pooler == 'sum': + return feat[mask].sum(0, keepdim=True) + elif self.pooler == 'max': + return feat[mask].max(0, keepdim=True) + else: + raise NotImplementedError diff --git a/trainer/utils/streamspot_parser.py b/trainer/utils/streamspot_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..04438eb0f6336ae2bad2015ad6caf4919f48f9a3 --- /dev/null +++ b/trainer/utils/streamspot_parser.py @@ -0,0 +1,53 @@ +import networkx as nx +from tqdm import tqdm +import json +raw_path = '../data/streamspot/' + +NUM_GRAPHS = 600 +node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] +edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] +node_type_set = set(node_type_dict) +edge_type_set = set(edge_type_dict) + +count_graph = 0 +with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: + lines = f.readlines() + g = nx.DiGraph() + node_map = {} + count_node = 0 + for line in tqdm(lines): + src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') + graph_id = int(graph_id) + if src_type not in node_type_set or dst_type not in node_type_set: + continue + if etype not in edge_type_set: + continue + if graph_id != count_graph: + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) + assert graph_id == count_graph + g = nx.DiGraph() + count_node = 0 + if src not in node_map: + node_map[src] = count_node + g.add_node(count_node, type=src_type) + count_node += 1 + if dst not in node_map: + node_map[dst] = count_node + g.add_node(count_node, type=dst_type) + count_node += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=etype) + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = 
node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7 --- /dev/null +++ b/trainer/utils/trace_parser.py @@ -0,0 +1,288 @@ +import argparse +import json +import os +import random +import re + +from tqdm import tqdm +import networkx as nx +import pickle as pkl + + +node_type_dict = {} +edge_type_dict = {} +node_type_cnt = 0 +edge_type_cnt = 0 + +metadata = { + 'trace':{ + 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], + 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] + }, + 'theia':{ + 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], + 'test': ['ta1-theia-e3-official-6r.json.8'] + }, + 'cadets':{ + 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], + 'test': ['ta1-cadets-e3-official-2.json'] + } +} + + +pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') +pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_type = re.compile(r'type\":\"(.*?)\"') +pattern_time = re.compile(r'timestampNanos\":(.*?),') +pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') +pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') +pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') + +adversarial = 'None' + + +def read_single_graph(dataset, malicious, path, test=False): + global node_type_cnt, edge_type_cnt + g = nx.DiGraph() + print('converting {} ...'.format(path)) + path = '../data/{}/'.format(dataset) + path + '.txt' + f = open(path, 'r') + lines = [] + for l in f.readlines(): + split_line = l.split('\t') + src, src_type, dst, dst_type, edge_type, ts = split_line + ts = int(ts) + if not test: + if src in malicious or dst in malicious: + if src in malicious and src_type != 'MemoryObject': + continue + if dst in malicious and dst_type != 'MemoryObject': + continue + + if src_type not in node_type_dict: + node_type_dict[src_type] = node_type_cnt + node_type_cnt += 1 + if dst_type not in node_type_dict: + node_type_dict[dst_type] = node_type_cnt + node_type_cnt += 1 + if edge_type not in edge_type_dict: + edge_type_dict[edge_type] = edge_type_cnt + edge_type_cnt += 1 + if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: + lines.append([dst, src, dst_type, src_type, edge_type, ts]) + else: + lines.append([src, dst, src_type, dst_type, edge_type, ts]) + lines.sort(key=lambda l: l[5]) + + node_map = {} + node_type_map = {} + node_cnt = 0 + node_list = [] + for l in lines: + src, dst, src_type, dst_type, edge_type = l[:5] + src_type_id = node_type_dict[src_type] + dst_type_id = node_type_dict[dst_type] + edge_type_id = edge_type_dict[edge_type] + if 
src not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if src in malicious: + src_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if src not in malicious: + i = random.randint(1, 20) + if i == 1: + src_type_id = node_type_dict['NetFlowObject'] + node_map[src] = node_cnt + g.add_node(node_cnt, type=src_type_id) + node_list.append(src) + node_type_map[src] = src_type + node_cnt += 1 + if dst not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if dst in malicious: + dst_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if dst not in malicious: + i = random.randint(1, 20) + if i == 1: + dst_type_id = node_type_dict['NetFlowObject'] + node_map[dst] = node_cnt + g.add_node(node_cnt, type=dst_type_id) + node_type_map[dst] = dst_type + node_list.append(dst) + node_cnt += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=edge_type_id) + if (adversarial == 'MSE' or adversarial == 'MCE') and test: + for i, node in enumerate(node_list): + if node in malicious: + while True: + another_node = random.choice(node_list) + if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) + print(node, another_node) + break + if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) + print(node, another_node) + break + if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) + print(node, another_node) + break + if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': + break + return node_map, g + + +def preprocess_dataset(dataset): + id_nodetype_map = {} + id_nodename_map = {} + for file in os.listdir('../data/{}/'.format(dataset)): + if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: + print('reading {} ...'.format(file)) + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue + if len(pattern_uuid.findall(line)) == 0: print(line) + uuid = pattern_uuid.findall(line)[0] + subject_type = pattern_type.findall(line) + + if len(subject_type) < 1: + if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: + subject_type = 'MemoryObject' + if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: + subject_type = 'NetFlowObject' + if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: + subject_type = 'UnnamedPipeObject' + else: + subject_type = subject_type[0] + + if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: + continue + id_nodetype_map[uuid] = subject_type + if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: + id_nodename_map[uuid] = 
pattern_file_name.findall(line)[0] + elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_process_name.findall(line)[0] + elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] + for key in metadata[dataset]: + for file in metadata[dataset][key]: + if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): + continue + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') + print('processing {} ...'.format(file)) + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line: + edgeType = pattern_type.findall(line)[0] + timestamp = pattern_time.findall(line)[0] + srcId = pattern_src.findall(line) + + if len(srcId) == 0: continue + srcId = srcId[0] + if not srcId in id_nodetype_map: + continue + srcType = id_nodetype_map[srcId] + dstId1 = pattern_dst1.findall(line) + if len(dstId1) > 0 and dstId1[0] != 'null': + dstId1 = dstId1[0] + if not dstId1 in id_nodetype_map: + continue + dstType1 = id_nodetype_map[dstId1] + this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( + dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge1) + + dstId2 = pattern_dst2.findall(line) + if len(dstId2) > 0 and dstId2[0] != 'null': + dstId2 = dstId2[0] + if not dstId2 in id_nodetype_map.keys(): + continue + dstType2 = id_nodetype_map[dstId2] + this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( + dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge2) + fw.close() + f.close() + if len(id_nodename_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') + json.dump(id_nodename_map, fw) + if len(id_nodetype_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') + json.dump(id_nodetype_map, fw) + + +def read_graphs(dataset): + malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) + f = open(malicious_entities, 'r') + malicious_entities = set() + for l in f.readlines(): + malicious_entities.add(l.lstrip().rstrip()) + + preprocess_dataset(dataset) + train_gs = [] + for file in metadata[dataset]['train']: + _, train_g = read_single_graph(dataset, malicious_entities, file, False) + train_gs.append(train_g) + test_gs = [] + test_node_map = {} + count_node = 0 + for file in metadata[dataset]['test']: + node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) + assert len(node_map) == test_g.number_of_nodes() + test_gs.append(test_g) + for key in node_map: + if key not in test_node_map: + test_node_map[key] = node_map[key] + count_node + count_node += test_g.number_of_nodes() + + if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): + with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodename_map = json.load(f) + with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodetype_map = json.load(f) + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 
'UnnamedPipeObject': + final_malicious_entities.append(test_node_map[e]) + if e in id_nodename_map: + malicious_names.append(id_nodename_map[e]) + f.write('{}\t{}\n'.format(e, id_nodename_map[e])) + else: + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + else: + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map: + final_malicious_entities.append(test_node_map[e]) + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + + pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CDM Parser') + parser.add_argument("--dataset", type=str, default="trace") + args = parser.parse_args() + if args.dataset not in ['trace', 'theia', 'cadets']: + raise NotImplementedError + read_graphs(args.dataset) + diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..bcf9a481407696d73f30fe1dde279154a05702b3 --- /dev/null +++ b/trainer/utils/utils.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +from functools import partial +import numpy as np +import random +import torch.optim as optim + + +def create_optimizer(opt, model, lr, weight_decay): + opt_lower = opt.lower() + parameters = model.parameters() + opt_args = dict(lr=lr, weight_decay=weight_decay) + optimizer = None + opt_split = opt_lower.split("_") + opt_lower = opt_split[-1] + if opt_lower == "adam": + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == "adamw": + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == "adadelta": + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == "radam": + optimizer = optim.RAdam(parameters, **opt_args) + elif opt_lower == "sgd": + opt_args["momentum"] = 0.9 + return optim.SGD(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + return optimizer + + +def random_shuffle(x, y): + idx = list(range(len(x))) + random.shuffle(idx) + return x[idx], y[idx] + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.determinstic = True + + +def create_activation(name): + if name == "relu": + return nn.ReLU() + elif name == "gelu": + return nn.GELU() + elif name == "prelu": + return nn.PReLU() + elif name is None: + return nn.Identity() + elif name == "elu": + return nn.ELU() + else: + raise NotImplementedError(f"{name} is not implemented.") + + +def create_norm(name): + if name == "layernorm": + return nn.LayerNorm + elif name == "batchnorm": + return nn.BatchNorm1d + elif name == "graphnorm": + return partial(NormLayer, norm_type="groupnorm") + else: + return None + + +class NormLayer(nn.Module): + def __init__(self, hidden_dim, norm_type): + super().__init__() + if norm_type == "batchnorm": + self.norm = nn.BatchNorm1d(hidden_dim) + elif norm_type == "layernorm": + self.norm = nn.LayerNorm(hidden_dim) + elif norm_type == "graphnorm": + self.norm = norm_type + self.weight = nn.Parameter(torch.ones(hidden_dim)) + self.bias = 
nn.Parameter(torch.zeros(hidden_dim)) + + self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) + else: + raise NotImplementedError + + def forward(self, graph, x): + tensor = x + if self.norm is not None and type(self.norm) != str: + return self.norm(tensor) + elif self.norm is None: + return tensor + + batch_list = graph.batch_num_nodes + batch_size = len(batch_list) + batch_list = torch.Tensor(batch_list).long().to(tensor.device) + batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) + batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) + mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) + mean = mean.scatter_add_(0, batch_index, tensor) + mean = (mean.T / batch_list).T + mean = mean.repeat_interleave(batch_list, dim=0) + + sub = tensor - mean * self.mean_scale + + std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) + std = std.scatter_add_(0, batch_index, sub.pow(2)) + std = ((std.T / batch_list).T + 1e-6).sqrt() + std = std.repeat_interleave(batch_list, dim=0) + return self.weight * sub / std + self.bias diff --git a/trainer/utils/wget_parser.py b/trainer/utils/wget_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055 --- /dev/null +++ b/trainer/utils/wget_parser.py @@ -0,0 +1,820 @@ +import argparse +import json +import os + +import xxhash +import tqdm +import logging +import networkx as nx +from tqdm import tqdm +import time +import datetime + + +valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] +CONSOLE_ARGUMENTS = None + + +def hashgen(l): + """Generate a single hash value from a list. @l is a list of + string values, which can be properties of a node/edge. This + function returns a single hashed integer value.""" + hasher = xxhash.xxh64() + for e in l: + hasher.update(e) + return hasher.intdigest() + + +def parse_nodes(json_string, node_map): + """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). + Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, + which is assigned by CamFlow to uniquely identify a node object, to a hashed + value (in str) which represents the 'type' of the node. """ + json_object = None + try: + # use "ignore" if non-decodeable exists in the @json_string + json_object = json.loads(json_string) + except Exception as e: + print("Exception ({}) occurred when parsing a node in JSON:".format(e)) + print(json_string) + exit(1) + if "activity" in json_object: + activity = json_object["activity"] + for uid in activity: + if not uid in node_map: # only parse unseen nodes + if "prov:type" not in activity[uid]: + # a node must have a type. + # record this issue if logging is turned on + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = activity[uid]["prov:type"] + + if "entity" in json_object: + entity = json_object["entity"] + for uid in entity: + if not uid in node_map: + if "prov:type" not in entity[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = entity[uid]["prov:type"] + + +def parse_all_nodes(filename, node_map): + """Parse all nodes in CamFlow data. @filename is the file path of + the CamFlow data to parse. @node_map contains the mappings of all + CamFlow nodes to their hashed attributes. 
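The `hashgen` helper defined above is what the edge-writing pass further down uses (via `hashgen([srcUUID])`) to turn CamFlow UUID strings into compact integer node ids when `noencode` is not set. A tiny usage sketch, assuming only the `xxhash` package already imported in this file; the UUID value is made up for illustration.

```
import xxhash


def hashgen(values):
    # Fold a list of strings into a single 64-bit integer, as in wget_parser.
    hasher = xxhash.xxh64()
    for v in values:
        hasher.update(v)
    return hasher.intdigest()


# Made-up UUID purely for illustration.
print(hashgen(['8c2b8c3f-1111-4b2e-9c1a-2f6d00000001']))
```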
""" + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(filename, 'r') as f: + # each line in CamFlow data could contain multiple + # provenance nodes, we call @parse_nodes routine. + for line in f: + pb.update() # for progress tracking + parse_nodes(line, node_map) + f.close() + pb.close() + + +def parse_all_edges(inputfile, outputfile, node_map, noencode): + """Parse all edges (including their timestamp) from CamFlow data file @inputfile + to an @outputfile. Before this function is called, parse_all_nodes should be called + to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, + we do not hash the nodes' original UUIDs generated by CamFlow to integers. This + function returns the total number of valid edge parsed from CamFlow dataset. + + The output edgelist has the following format for each line, if -s is not set: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> + If -s is set, each line would look like: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" + total_edges = 0 + smallest_timestamp = None + # scan through the entire file to find the smallest timestamp from all the edges. + # this step is only needed if we need to add some statistical information. + if CONSOLE_ARGUMENTS.stats: + description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + continue + if "cf:date" not in used[uid]: + continue + if "prov:entity" not in used[uid]: + continue + if "prov:activity" not in used[uid]: + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + continue + if "cf:date" not in wasGeneratedBy[uid]: + continue + if "prov:entity" not in wasGeneratedBy[uid]: + continue + if "prov:activity" not in wasGeneratedBy[uid]: + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + continue + if "cf:date" not in wasInformedBy[uid]: + continue + if "prov:informant" not in wasInformedBy[uid]: + continue + if "prov:informed" not in wasInformedBy[uid]: + 
continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + continue + if "cf:date" not in wasDerivedFrom[uid]: + continue + if "prov:usedEntity" not in wasDerivedFrom[uid]: + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + continue + if "cf:date" not in wasAssociatedWith[uid]: + continue + if "prov:agent" not in wasAssociatedWith[uid]: + continue + if "prov:activity" not in wasAssociatedWith[uid]: + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + f.close() + pb.close() + + # we will go through the CamFlow data (again) and output edgelist to a file + output = open(outputfile, "w+") + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + # an edge must have a type; if not, + # we will have to skip the edge. Log + # this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without type: {}".format(uid)) + continue + else: + edgetype = "used" + # cf:id is used as logical timestamp to order edges + if "cf:id" not in used[uid]: + # an edge must have a logical timestamp; + # if not we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = used[uid]["cf:id"] + if "prov:entity" not in used[uid]: + # an edge's source node must exist; + # if not, we will have to skip the + # edge. Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) + continue + if "prov:activity" not in used[uid]: + # an edge's destination node must exist; + # if not, we will have to skip the edge. 
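To make the field checks in this pass concrete, here is a sketch of a minimal `used` relation that would survive every guard above, together with the edgelist line it yields when neither stats nor jiffies output is enabled and `noencode` is true. All values are invented for illustration; with encoding on, the two UUID columns would instead be `hashgen([uuid])` integers.

```
# Made-up "used" relation carrying every field the parser checks.
node_map = {"AAAA-1111": "task", "BBBB-2222": "file"}
used_record = {
    "prov:type": "used",
    "cf:id": 42,                        # logical timestamp, orders the edges
    "prov:entity": "BBBB-2222",         # source UUID, must appear in node_map
    "prov:activity": "AAAA-1111",       # destination UUID, must appear in node_map
    "cf:date": "2024:01:01T12:30:00",   # parsed with "%Y:%m:%dT%H:%M:%S"
    "cf:jiffies": "4294907296",         # only written out when jiffies is enabled
}

src, dst = used_record["prov:entity"], used_record["prov:activity"]
line = "{}\t{}\t{}:{}:{}:{}\n".format(
    src, dst, node_map[src], node_map[dst], "used", used_record["cf:id"])
print(line, end="")  # BBBB-2222	AAAA-1111	file:task:used:42
```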
+ # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], + uid)) + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + # both source and destination node must + # exist in @node_map; if not, we will + # have to skip the edge. Log this issue + # if verbose is set. + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in used[uid]: + # an edge must have a timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without timestamp: {}".format(uid)) + continue + else: + # we only record @adjusted_ts if we need + # to record stats of CamFlow dataset. + if CONSOLE_ARGUMENTS.stats: + ts_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in used[uid]: + # an edge must have a jiffies timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without jiffies: {}".format(uid)) + continue + else: + # we only record @jiffies if + # the option is set + if CONSOLE_ARGUMENTS.jiffies: + jiffies = used[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasGeneratedBy" + if "cf:id" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasGeneratedBy[uid]["cf:id"] + if "prov:entity" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasGeneratedBy[uid]: + if 
CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasGeneratedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasInformedBy" + if "cf:id" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasInformedBy[uid]["cf:id"] + if "prov:informant" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + if "prov:informed" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( + 
wasInformedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasInformedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) + continue + else: + edgetype = "wasDerivedFrom" + if "cf:id" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasDerivedFrom[uid]["cf:id"] + if "prov:usedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + 
logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasDerivedFrom[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) + continue + else: + edgetype = "wasAssociatedWith" + if "cf:id" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasAssociatedWith[uid]["cf:id"] + if "prov:agent" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasAssociatedWith[uid]: 
+ if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasAssociatedWith[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + f.close() + output.close() + pb.close() + return total_edges + +def read_single_graph(file_name, threshold): + graph = [] + edge_cnt = 0 + with open(file_name, 'r') as f: + for line in f: + try: + edge = line.strip().split("\t") + new_edge = [edge[0], edge[1]] + attributes = edge[2].strip().split(":") + source_node_type = attributes[0] + destination_node_type = attributes[1] + edge_type = attributes[2] + edge_order = attributes[3] + + new_edge.append(source_node_type) + new_edge.append(destination_node_type) + new_edge.append(edge_type) + new_edge.append(edge_order) + graph.append(new_edge) + edge_cnt += 1 + except: + print("{}".format(line)) + f.close() + graph.sort(key=lambda e: e[5]) + if len(graph) <= threshold: + return graph + else: + return graph[:threshold] + + +def process_graph(name, threshold): + graph = read_single_graph(name, threshold) + result_graph = nx.DiGraph() + cnt = 0 + for num, edge in enumerate(graph): + cnt += 1 + src, dst, src_type, dst_type, edge_type = edge[:5] + if True:# src_type in valid_node_type and dst_type in valid_node_type: + if not result_graph.has_node(src): + result_graph.add_node(src, type=src_type) + if not result_graph.has_node(dst): + result_graph.add_node(dst, type=dst_type) + if not result_graph.has_edge(src, dst): + result_graph.add_edge(src, dst, type=edge_type) + if bidirection: + result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) + return cnt, result_graph + + +node_type_list = [] +edge_type_list = [] +node_type_dict = {} +edge_type_dict = {} + + +def format_graph(g, name): + new_g = nx.DiGraph() + node_map = {} + node_cnt = 0 + for n in g.nodes: + node_map[n] = node_cnt + new_g.add_node(node_cnt, type=g.nodes[n]['type']) + node_cnt += 1 + for e in g.edges: + new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) + for n in new_g.nodes: + node_type = new_g.nodes[n]['type'] + if not node_type in node_type_dict: + node_type_list.append(node_type) + node_type_dict[node_type] = 1 + else: + node_type_dict[node_type] += 1 + for e in new_g.edges: + edge_type = new_g.edges[e]['type'] + if not edge_type in edge_type_dict: + edge_type_list.append(edge_type) + edge_type_dict[edge_type] = 1 + else: + edge_type_dict[edge_type] += 1 + for n in new_g.nodes: + new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) + for e in new_g.edges: + new_g.edges[e]['type'] 
= edge_type_list.index(new_g.edges[e]['type']) + with open('{}.json'.format(name), 'w', encoding='utf-8') as f: + json.dump(nx.node_link_data(new_g), f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist') + args = parser.parse_args() + args.stats = False + args.verbose = False + args.jiffies = False + args.input = '../data/wget/raw/' + args.output = '../data/wget/processed/' + args.final_output = '../data/wget/final/' + args.noencode = False + if not os.path.exists(args.input): + os.mkdir(args.input) + if not os.path.exists(args.output): + os.mkdir(args.output) + if not os.path.exists(args.final_output): + os.mkdir(args.final_output) + CONSOLE_ARGUMENTS = args + + if args.verbose: + logging.basicConfig(filename=args.log, level=logging.DEBUG) + cnt = 0 + for fname in os.listdir(args.input): + cnt += 1 + node_map = dict() + parse_all_nodes(args.input + '/{}'.format(fname), node_map) + total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map, + args.noencode) + if args.stats: + total_nodes = len(node_map) + stats = open(args.stats_file + '/{}.log'.format(cnt), "a+") + stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges)) + + bidirection = False + threshold = 10000000 # infinity + interaction_dict = [] + graph_cnt = 0 + result_graphs = [] + input = args.output + base = args.final_output + + line_cnt = 0 + for i in tqdm(range(cnt)): + single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold) + format_graph(result_graph, '{}{}'.format(base, i)) + line_cnt += single_cnt + + print(line_cnt // 150) + print(len(node_type_list)) + print(node_type_dict) + print(len(edge_type_list)) + print(edge_type_dict) + diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc deleted file mode 100644 index ca2a48894d4871479b958f92c0ae288f5c3e6cad..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/config.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/config.cpython-311.pyc b/utils/__pycache__/config.cpython-311.pyc old mode 100644 new mode 100755 index 0a2ca013c4c0826c94b817a15f5aeaadf3fbf021..55cf6b81b72fe73af01f3176b76fd0081827f1e8 Binary files a/utils/__pycache__/config.cpython-311.pyc and b/utils/__pycache__/config.cpython-311.pyc differ diff --git a/utils/__pycache__/configJupyter.cpython-311.pyc b/utils/__pycache__/configJupyter.cpython-311.pyc deleted file mode 100644 index 470b992f5214101b8dcf54596f455e87c14e0d9b..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/configJupyter.cpython-311.pyc and /dev/null differ diff --git a/utils/__pycache__/dataloader.cpython-310.pyc b/utils/__pycache__/dataloader.cpython-310.pyc deleted file mode 100644 index d87ef800e3349966c54f8aff0158c342c15ce957..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/dataloader.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/dataloader.cpython-311.pyc b/utils/__pycache__/dataloader.cpython-311.pyc deleted file mode 100644 index 435191a8fbab8cf94e7ec9ff76aa01f782a87157..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/dataloader.cpython-311.pyc and /dev/null differ diff --git a/utils/__pycache__/loaddata.cpython-311.pyc b/utils/__pycache__/loaddata.cpython-311.pyc old mode 100644 new mode 100755 index 
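For reference, the per-graph files written by `format_graph` are plain node-link JSON: node ids are remapped to `0..n-1` and node/edge `type` attributes are already integer indices at that point. A minimal sketch with an invented two-node graph (key order in the printed dict may vary by networkx version):

```
import json
import networkx as nx

g = nx.DiGraph()
g.add_node(0, type=0)
g.add_node(1, type=1)
g.add_edge(0, 1, type=2)

print(json.dumps(nx.node_link_data(g)))
# e.g. {"directed": true, "multigraph": false, "graph": {},
#       "nodes": [{"type": 0, "id": 0}, {"type": 1, "id": 1}],
#       "links": [{"type": 2, "source": 0, "target": 1}]}
```

Files in this node-link format are what the dataset loaders elsewhere in this diff read back via `nx.node_link_graph` and `dgl.from_networkx(..., node_attrs=['type'], edge_attrs=['type'])`.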
516f605481c393aef9331a111ab596d8f327e0ee..6dadf00a8f9a10b0fd4c0f82c44e99ec4cd3ffa0 Binary files a/utils/__pycache__/loaddata.cpython-311.pyc and b/utils/__pycache__/loaddata.cpython-311.pyc differ diff --git a/utils/__pycache__/poolers.cpython-310.pyc b/utils/__pycache__/poolers.cpython-310.pyc deleted file mode 100644 index 441c02bbdd690fbfccc5acf786b9ded0f067bd8d..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/poolers.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/poolers.cpython-311.pyc b/utils/__pycache__/poolers.cpython-311.pyc old mode 100644 new mode 100755 index 4464f0cddb9bf9523f2ee1a43193ac41ebdb2f73..d761b00f03cf59c7afe8815f579b18fed2f7dd8f Binary files a/utils/__pycache__/poolers.cpython-311.pyc and b/utils/__pycache__/poolers.cpython-311.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 10658e8f670d580962f3f3982605978f526561a7..0000000000000000000000000000000000000000 Binary files a/utils/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/utils/__pycache__/utils.cpython-311.pyc b/utils/__pycache__/utils.cpython-311.pyc old mode 100644 new mode 100755 index 16c0ff7b920416d5844cb0abf8995663dd0bf741..f2d041d338d864c03c2baf1f746aaa8bba806202 Binary files a/utils/__pycache__/utils.cpython-311.pyc and b/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/utils/config.py b/utils/config.py old mode 100644 new mode 100755 index 75b33c7e63bcbcab20114e07ab7075922c59417b..ee14920f4b2e36dfcd0d7fddca100572837ade20 --- a/utils/config.py +++ b/utils/config.py @@ -1,23 +1,13 @@ -import argparse - - def build_args(): - parser = argparse.ArgumentParser(description="MAGIC") - parser.add_argument("--device", type=int, default=-1) - parser.add_argument("--lr", type=float, default=0.0001, - help="learning rate") - parser.add_argument("--weight_decay", type=float, default=5e-4, - help="weight decay") - parser.add_argument("--negative_slope", type=float, default=0.2, - help="the negative slope of leaky relu for GAT") - parser.add_argument("--mask_rate", type=float, default=0.5) - parser.add_argument("--alpha_l", type=float, default=3, help="`pow`inddex for `sce` loss") - parser.add_argument("--optimizer", type=str, default="adam") - parser.add_argument("--loss_fn", type=str, default='sce') - parser.add_argument("--pooling", type=str, default="mean") - parser.add_argument("--run_id", type=str, default="fedgraphnn_cs_ch") - parser.add_argument("--cf", type=str, default="fedml_config.yaml") - parser.add_argument("--rank", type=int, default=0) - parser.add_argument("--role", type=str, default="server") - args = parser.parse_args() - return args \ No newline at end of file + args = {} + args["dataset"] ="wget" + args["device"]=-1 + args["lr"]=0.001 + args["weight_decay"]=5e-4 + args["negative_slope"]=0.2 + args["mask_rate"]=0.5 + args["alpha_l"]=3 + args["optimizer"]="adam" + args["loss_fn"]='sce' + args["pooling"] = "mean" + return args diff --git a/utils/configJupyter.py b/utils/configJupyter.py deleted file mode 100644 index 269779745a9c835b757ea725d201e11e5ee3e87b..0000000000000000000000000000000000000000 --- a/utils/configJupyter.py +++ /dev/null @@ -1,14 +0,0 @@ - -def build_args(): - - args = {} - args['device'] = -1 - args['lr'] = 0.01 - args['weight_decay'] = 5e-4 - args['negative_slope'] = 0.2 - args['mask_rate'] = 0.5 - args['alpha_l'] = 3 - args['optimizer'] = 'adam' - args['loss_fn'] = "sce" - args['pooling'] = "mean" - return args \ No 
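Since the argparse-based CLI is replaced here by a plain dictionary, callers now read and override hyperparameters directly. A hypothetical usage sketch, assuming the repository layout shown in this diff:

```
from utils.config import build_args

args = build_args()
print(args["dataset"], args["lr"], args["pooling"])  # wget 0.001 mean
args["lr"] = 0.0005  # hyperparameters can now be overridden in plain Python
```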
newline at end of file diff --git a/utils/dataloader.py b/utils/dataloader.py deleted file mode 100644 index 992c6fcacaee228330db91e920c03c6d4f2fc8ad..0000000000000000000000000000000000000000 --- a/utils/dataloader.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import random -import networkx as nx -import dgl -import torch -import pickle as pkl -import json -import logging -import numpy as np - -path_dataset = 'D:/PFE DATASETS/' - -def darpa_split(name): - metadata = load_metadata(name) - n_train = metadata['n_train'] - train_dataset = range(n_train) - train_labels = [0]* n_train - - - return ( - train_dataset, - train_labels, - [], - [], - [], - [] - ) - - -def create_random_split(name, snapshots): - dataset = load_data(name) - # Random 80/10/10 split as suggested - - - all_idxs = list(range(len(dataset))) - random.shuffle(all_idxs) - - train_dataset = dataset['train_index'] - train_labels = [] - for id in train_dataset: - train_labels.append(dataset['labels'][id]) - - val_dataset = dataset['validation_index'] - val_labels = [] - for id in val_dataset: - val_labels.append(dataset['labels'][id]) - - test_dataset = dataset['test_index'] - test_labels = [] - for id in test_dataset: - test_labels.append(dataset['labels'][id]) - - - return ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) - - - -def partition_data_by_sample_size( - client_number, name, snapshots -): - if (name in ['wget', 'streamspot', 'SC2', 'Unicorn-Cadets', 'wget-long', 'clearscope-e3']): - ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) = create_random_split(name, snapshots) - else: - ( - train_dataset, - train_labels, - val_dataset, - val_labels, - test_dataset, - test_labels, - ) = darpa_split(name) - - num_train_samples = len(train_dataset) - num_val_samples = len(val_dataset) - num_test_samples = len(test_dataset) - - train_idxs = list(range(num_train_samples)) - val_idxs = list(range(num_val_samples)) - test_idxs = list(range(num_test_samples)) - - random.shuffle(train_idxs) - random.shuffle(val_idxs) - random.shuffle(test_idxs) - - partition_dicts = [None] * client_number - - - clients_idxs_train = np.array_split(train_idxs, client_number) - clients_idxs_val = np.array_split(val_idxs, client_number) - clients_idxs_test = np.array_split(test_idxs, client_number) - - labels_of_all_clients = [] - for client in range(client_number): - client_train_idxs = clients_idxs_train[client] - client_val_idxs = clients_idxs_val[client] - client_test_idxs = clients_idxs_test[client] - - train_dataset_client = [ - train_dataset[idx] for idx in client_train_idxs - ] - train_labels_client = [train_labels[idx] for idx in client_train_idxs] - labels_of_all_clients.append(train_labels_client) - - val_dataset_client = [val_dataset[idx] for idx in client_val_idxs] - val_labels_client = [val_labels[idx] for idx in client_val_idxs] - - test_dataset_client = [test_dataset[idx] for idx in client_test_idxs] - test_labels_client = [test_labels[idx] for idx in client_test_idxs] - - - partition_dict = { - "train": train_dataset_client, - "val": val_dataset_client, - "test": test_dataset_client, - } - - partition_dicts[client] = partition_dict - global_data_dict = { - "train": train_dataset, - "val": val_dataset, - "test": test_dataset, - } - - return global_data_dict, partition_dicts - -def load_partition_data( - client_number, - name, - snapshots, - global_test=True, -): - global_data_dict, partition_dicts = partition_data_by_sample_size( - 
client_number, name, snapshots - ) - - data_local_num_dict = dict() - train_data_local_dict = dict() - val_data_local_dict = dict() - test_data_local_dict = dict() - - - - # IT IS VERY IMPORTANT THAT THE BATCH SIZE = 1. EACH BATCH IS AN ENTIRE MOLECULE. - train_data_global = global_data_dict["train"] - val_data_global = global_data_dict["val"] - test_data_global = global_data_dict["test"] - train_data_num = len(global_data_dict["train"]) - val_data_num = len(global_data_dict["val"]) - test_data_num = len(global_data_dict["test"]) - - for client in range(client_number): - train_dataset_client = partition_dicts[client]["train"] - val_dataset_client = partition_dicts[client]["val"] - test_dataset_client = partition_dicts[client]["test"] - - data_local_num_dict[client] = len(train_dataset_client) - train_data_local_dict[client] = train_dataset_client, - - val_data_local_dict[client] = val_dataset_client - - test_data_local_dict[client] = ( - test_data_global - if global_test - else test_dataset_client - - ) - - logging.info( - "Client idx = {}, local sample number = {}".format( - client, len(train_dataset_client) - ) - ) - - return ( - train_data_num, - val_data_num, - test_data_num, - train_data_global, - val_data_global, - test_data_global, - data_local_num_dict, - train_data_local_dict, - val_data_local_dict, - test_data_local_dict, - ) - - - - - - - -def preload_entity_level_dataset(name): - path = path_dataset + name - if os.path.exists(path + '/metadata.json'): - pass - else: - - malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) - - n_train = len(os.listdir(path + '/train')) - n_test = len(os.listdir(path + '/test')) - - g = pkl.load(open(path + '/train/graph0/graph0.pkl', 'rb')) - - node_feature_dim = len(g.ndata['attr'][0]) - edge_feature_dim = len(g.edata['attr'][0]) - - metadata = { - 'node_feature_dim': node_feature_dim, - 'edge_feature_dim': edge_feature_dim, - 'malicious': malicious, - 'n_train': n_train, - 'n_test': n_test - } - with open(path + '/metadata.json', 'w', encoding='utf-8') as f: - json.dump(metadata, f) - - - -def load_metadata(name): - preload_entity_level_dataset(name) - with open( path_dataset + name + '/metadata.json', 'r', encoding='utf-8') as f: - metadata = json.load(f) - return metadata - - -def load_entity_level_dataset(name, t, n, snapshot, device): - preload_entity_level_dataset(name) - graphs = [] - for i in range(snapshot): - with open(path_dataset + name + '/' + t + '/graph{}/graph{}.pkl'.format(n, str(i)), 'rb') as f: - graphs.append(pkl.load(f).to(device)) - return graphs - - -def get_labels(name): - if (name=="wget" ): - return [1] * 25 + [0] * 125 - elif (name=="streamspot"): - return [0] * 300 + [1] * 100 + [0] * 200 - elif (name == 'SC2'): - return [0] * 125 + [1] * 25 - elif (name == 'Unicorn-Cadets'): - return [0] * 109 + [1] * 3 - elif (name == 'wget-long'): - return [0] * 100 + [1] * 5 - elif (name == 'clearscope-e3'): - return [0] * 44 + [1] * 50 - -def load_data(name): - if name == "wget": - n, n_dim, e_dim = 150, 14, 4 - full_dataset_index = list(range(n)) - train_dataset = list(range(50, 150)) - validation_dataset = list(range(50)) - test_dataset = list(range(50)) - elif name == "streamspot": - n, n_dim, e_dim = 600, 8, 26 - full_dataset_index = list(range(n)) - train_dataset = list(range(300)) - validation_dataset = list(range(300, 350)) + list(range(500,550)) - test_dataset = list(range(300, 400)) + list(range(400,500))+ list(range(500,600)) - elif name == 'SC2': - n_dim = len(pkl.load(open(path_dataset + 'SC2/node.pkl', 
'rb')).keys()) - e_dim = len(pkl.load(open(path_dataset + 'SC2/edge.pkl', 'rb')).keys()) - n, full_dataset_index = 150, list(range(150)) - train_dataset = list(range(100)) - validation_dataset = list(range(100, 150)) - test_dataset = list(range(100, 150)) - elif name in ['Unicorn-Cadets', 'wget-long', 'clearscope-e3']: - n_dim = len(pkl.load(open(path_dataset + '{}/node.pkl'.format(name), 'rb')).keys()) - e_dim = len(pkl.load(open(path_dataset + '{}/edge.pkl'.format(name), 'rb')).keys()) - if name == 'Unicorn-Cadets': - n, train_dataset = 112, list(range(70)) - elif name == 'wget-long': - n, train_dataset = 105, list(range(70)) - else: - n, train_dataset = 94, list(range(30)) - full_dataset_index = list(range(n)) - validation_dataset = list(range(train_dataset[-1], n)) - test_dataset = validation_dataset - return {'dataset': full_dataset_index, - 'train_index': train_dataset, - 'test_index': test_dataset, - 'validation_index': validation_dataset, - 'full_index': full_dataset_index, - 'n_feat': n_dim, - 'e_feat': e_dim, - 'labels': get_labels(name)} - - - -def load_graph(id, name ,device): - graphs = [] - - if (name == "wget"): - path = path_dataset + 'wget/cache/' + 'graph{}'.format(str(id)) - elif (name == "streamspot"): - path = path_dataset + 'streamspot/cache/' + 'graph{}'.format(str(id)) - elif (name == "SC2"): - if (id < 125): path = path_dataset + 'SC2/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'SC2/cache/attack/' + 'graph{}'.format(str(id - 125)) - elif (name == 'Unicorn-Cadets'): - if (id < 109): path = path_dataset + 'Unicorn-Cadets/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'Unicorn-Cadets/cache/attack/' + 'graph{}'.format(str(id - 109)) - elif (name == 'wget-long'): - if (id < 100): path = path_dataset + 'wget-long/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'wget-long/cache/attack/' + 'graph{}'.format(str(id - 100)) - elif (name == 'clearscope-e3'): - if (id < 44): path = path_dataset + 'clearscope-e3/cache/benign/' + 'graph{}'.format(str(id)) - else: path = path_dataset + 'clearscope-e3/cache/attack/' + 'graph{}'.format(str(id - 44)) - - for fname in os.listdir(path): - graphs.append(pkl.load(open(path + '/' + fname, 'rb')).to(device)) - - return graphs \ No newline at end of file diff --git a/utils/loaddata.py b/utils/loaddata.py new file mode 100755 index 0000000000000000000000000000000000000000..ca48e0fd60aa69712f48c06f8083b37f679cf797 --- /dev/null +++ b/utils/loaddata.py @@ -0,0 +1,207 @@ +import pickle as pkl +import time +import torch.nn.functional as F +import dgl +import networkx as nx +import json +from tqdm import tqdm +import os + + + +class StreamspotDataset(dgl.data.DGLDataset): + def process(self): + pass + + def __init__(self, name): + super(StreamspotDataset, self).__init__(name=name) + if name == 'streamspot': + path = './data/streamspot' + num_graphs = 600 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(path, str(idx + 1))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + if 300 <= idx <= 399: + self.labels.append(1) + else: + self.labels.append(0) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +class WgetDataset(dgl.data.DGLDataset): + def process(self): + pass + + def 
__init__(self, name): + super(WgetDataset, self).__init__(name=name) + if name == 'wget': + pathattack = '/data/wget/finalattack' + pathbenin = 'data/wget/finalbenin' + num_graphs_benin = 125 + num_graphs_attack = 25 + self.graphs = [] + self.labels = [] + print('Loading {} dataset...'.format(name)) + for i in tqdm(range(num_graphs_benin)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathbenin, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(0) + + for i in tqdm(range(num_graphs_attack)): + idx = i + g = dgl.from_networkx( + nx.node_link_graph(json.load(open('{}/{}.json'.format(pathattack, str(idx))))), + node_attrs=['type'], + edge_attrs=['type'] + ) + self.graphs.append(g) + self.labels.append(1) + else: + raise NotImplementedError + + def __getitem__(self, i): + return self.graphs[i], self.labels[i] + + def __len__(self): + return len(self.graphs) + + +def load_rawdata(name): + if name == 'streamspot': + path = './data/streamspot' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = StreamspotDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + elif name == 'wget': + path = './data/wget' + if os.path.exists(path + '/graphs.pkl'): + print('Loading processed {} dataset...'.format(name)) + raw_data = pkl.load(open(path + '/graphs.pkl', 'rb')) + else: + raw_data = WgetDataset(name) + pkl.dump(raw_data, open(path + '/graphs.pkl', 'wb')) + else: + raise NotImplementedError + return raw_data + + +def load_batch_level_dataset(dataset_name): + dataset = load_rawdata(dataset_name) + graph, _ = dataset[0] + node_feature_dim = 0 + for g, _ in dataset: + node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item()) + edge_feature_dim = 0 + for g, _ in dataset: + edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item()) + node_feature_dim += 1 + edge_feature_dim += 1 + full_dataset = [i for i in range(len(dataset))] + train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0] + print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim)) + + return {'dataset': dataset, + 'train_index': train_dataset, + 'full_index': full_dataset, + 'n_feat': node_feature_dim, + 'e_feat': edge_feature_dim} + + +def transform_graph(g, node_feature_dim, edge_feature_dim): + new_g = g.clone() + new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float() + new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float() + return new_g + + +def preload_entity_level_dataset(path): + path = './data/' + path + if os.path.exists(path + '/metadata.json'): + pass + else: + print('transforming') + train_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/train.pkl', 'rb'))] + print('transforming') + test_gs = [dgl.from_networkx( + nx.node_link_graph(g), + node_attrs=['type'], + edge_attrs=['type'] + ) for g in pkl.load(open(path + '/test.pkl', 'rb'))] + malicious = pkl.load(open(path + '/malicious.pkl', 'rb')) + + node_feature_dim = 0 + for g in train_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + for g in test_gs: + node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim) + node_feature_dim += 1 + 
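The `transform_graph` helper above turns the integer `type` labels into one-hot `attr` features, with `num_classes` set to the largest observed type id plus one. A pure-PyTorch sketch of that step (dgl omitted; the tensor is invented and stands in for `g.ndata["type"]` of a three-node graph):

```
import torch
import torch.nn.functional as F

node_types = torch.tensor([0, 2, 1])   # integer node-type ids
node_feature_dim = 3                   # max type id + 1, as computed above

attr = F.one_hot(node_types.view(-1), num_classes=node_feature_dim).float()
print(attr)
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])
```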
edge_feature_dim = 0 + for g in train_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + for g in test_gs: + edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim) + edge_feature_dim += 1 + result_test_gs = [] + for g in test_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_test_gs.append(g) + result_train_gs = [] + for g in train_gs: + g = transform_graph(g, node_feature_dim, edge_feature_dim) + result_train_gs.append(g) + metadata = { + 'node_feature_dim': node_feature_dim, + 'edge_feature_dim': edge_feature_dim, + 'malicious': malicious, + 'n_train': len(result_train_gs), + 'n_test': len(result_test_gs) + } + with open(path + '/metadata.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f) + for i, g in enumerate(result_train_gs): + with open(path + '/train{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + for i, g in enumerate(result_test_gs): + with open(path + '/test{}.pkl'.format(i), 'wb') as f: + pkl.dump(g, f) + + +def load_metadata(path): + preload_entity_level_dataset(path) + with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f: + metadata = json.load(f) + return metadata + + +def load_entity_level_dataset(path, t, n): + preload_entity_level_dataset(path) + with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f: + data = pkl.load(f) + return data diff --git a/utils/poolers.py b/utils/poolers.py old mode 100644 new mode 100755 diff --git a/utils/streamspot_parser.py b/utils/streamspot_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..07687828ea5ecf56f9859955e599afb584826f1f --- /dev/null +++ b/utils/streamspot_parser.py @@ -0,0 +1,55 @@ +import networkx as nx +from tqdm import tqdm +import json +raw_path = '../data/streamspot/' + +NUM_GRAPHS = 600 +node_type_dict = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] +edge_type_dict = ['i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 't', 'u', 'v', 'w', 'y', 'z', 'A', 'C', 'D', 'E', 'G'] +node_type_set = set(node_type_dict) +edge_type_set = set(edge_type_dict) +print(raw_path) +count_graph = 0 +with open(raw_path + 'all.tsv', 'r', encoding='utf-8') as f: + print("Reading") + lines = f.readlines() + print("starting") + g = nx.DiGraph() + node_map = {} + count_node = 0 + for line in tqdm(lines): + src, src_type, dst, dst_type, etype, graph_id = line.strip('\n').split('\t') + graph_id = int(graph_id) + if src_type not in node_type_set or dst_type not in node_type_set: + continue + if etype not in edge_type_set: + continue + if graph_id != count_graph: + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + json.dump(nx.node_link_data(g), f1) + assert graph_id == count_graph + g = nx.DiGraph() + count_node = 0 + if src not in node_map: + node_map[src] = count_node + g.add_node(count_node, type=src_type) + count_node += 1 + if dst not in node_map: + node_map[dst] = count_node + g.add_node(count_node, type=dst_type) + count_node += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=etype) + count_graph += 1 + for n in g.nodes(): + g.nodes[n]['type'] = node_type_dict.index(g.nodes[n]['type']) + for e in g.edges(): + g.edges[e]['type'] = edge_type_dict.index(g.edges[e]['type']) + f1 = open(raw_path + str(count_graph) + '.json', 'w', encoding='utf-8') + 
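For reference, the parser above expects `all.tsv` lines with six tab-separated columns; a made-up example, split exactly as the loop does:

```
# Illustrative StreamSpot all.tsv line: values are invented, columns are
# src, src_type, dst, dst_type, edge_type, graph_id.
line = "12\ta\t57\tc\tq\t0\n"
src, src_type, dst, dst_type, etype, graph_id = line.strip("\n").split("\t")
print(src, src_type, dst, dst_type, etype, int(graph_id))  # 12 a 57 c q 0
```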
json.dump(nx.node_link_data(g), f1) diff --git a/utils/trace_parser.py b/utils/trace_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7 --- /dev/null +++ b/utils/trace_parser.py @@ -0,0 +1,288 @@ +import argparse +import json +import os +import random +import re + +from tqdm import tqdm +import networkx as nx +import pickle as pkl + + +node_type_dict = {} +edge_type_dict = {} +node_type_cnt = 0 +edge_type_cnt = 0 + +metadata = { + 'trace':{ + 'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'], + 'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4'] + }, + 'theia':{ + 'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'], + 'test': ['ta1-theia-e3-official-6r.json.8'] + }, + 'cadets':{ + 'train': ['ta1-cadets-e3-official.json','ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'], + 'test': ['ta1-cadets-e3-official-2.json'] + } +} + + +pattern_uuid = re.compile(r'uuid\":\"(.*?)\"') +pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}') +pattern_type = re.compile(r'type\":\"(.*?)\"') +pattern_time = re.compile(r'timestampNanos\":(.*?),') +pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"') +pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"') +pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"') + +adversarial = 'None' + + +def read_single_graph(dataset, malicious, path, test=False): + global node_type_cnt, edge_type_cnt + g = nx.DiGraph() + print('converting {} ...'.format(path)) + path = '../data/{}/'.format(dataset) + path + '.txt' + f = open(path, 'r') + lines = [] + for l in f.readlines(): + split_line = l.split('\t') + src, src_type, dst, dst_type, edge_type, ts = split_line + ts = int(ts) + if not test: + if src in malicious or dst in malicious: + if src in malicious and src_type != 'MemoryObject': + continue + if dst in malicious and dst_type != 'MemoryObject': + continue + + if src_type not in node_type_dict: + node_type_dict[src_type] = node_type_cnt + node_type_cnt += 1 + if dst_type not in node_type_dict: + node_type_dict[dst_type] = node_type_cnt + node_type_cnt += 1 + if edge_type not in edge_type_dict: + edge_type_dict[edge_type] = edge_type_cnt + edge_type_cnt += 1 + if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type: + lines.append([dst, src, dst_type, src_type, edge_type, ts]) + else: + lines.append([src, dst, src_type, dst_type, edge_type, ts]) + lines.sort(key=lambda l: l[5]) + + node_map = {} + node_type_map = {} + node_cnt = 0 + node_list = [] + for l in lines: + src, dst, src_type, dst_type, edge_type = l[:5] + src_type_id = node_type_dict[src_type] + dst_type_id = node_type_dict[dst_type] + edge_type_id = edge_type_dict[edge_type] + if src not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if src in malicious: + src_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if src not in malicious: + i 
= random.randint(1, 20) + if i == 1: + src_type_id = node_type_dict['NetFlowObject'] + node_map[src] = node_cnt + g.add_node(node_cnt, type=src_type_id) + node_list.append(src) + node_type_map[src] = src_type + node_cnt += 1 + if dst not in node_map: + if test: + if adversarial == 'MFE' or adversarial == 'MCE': + if dst in malicious: + dst_type_id = node_type_dict['FILE_OBJECT_FILE'] + if not test: + if adversarial == 'BFP': + if dst not in malicious: + i = random.randint(1, 20) + if i == 1: + dst_type_id = node_type_dict['NetFlowObject'] + node_map[dst] = node_cnt + g.add_node(node_cnt, type=dst_type_id) + node_type_map[dst] = dst_type + node_list.append(dst) + node_cnt += 1 + if not g.has_edge(node_map[src], node_map[dst]): + g.add_edge(node_map[src], node_map[dst], type=edge_type_id) + if (adversarial == 'MSE' or adversarial == 'MCE') and test: + for i, node in enumerate(node_list): + if node in malicious: + while True: + another_node = random.choice(node_list) + if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE']) + print(node, another_node) + break + if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ']) + print(node, another_node) + break + if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS': + g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT']) + print(node, another_node) + break + if not 'FILE_OBJECT' in node_type_map[node] and not node_type_map[node] == 'SUBJECT_PROCESS' and not node_type_map[node] == 'NetFlowObject': + break + return node_map, g + + +def preprocess_dataset(dataset): + id_nodetype_map = {} + id_nodename_map = {} + for file in os.listdir('../data/{}/'.format(dataset)): + if 'json' in file and not '.txt' in file and not 'names' in file and not 'types' in file and not 'metadata' in file: + print('reading {} ...'.format(file)) + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue + if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue + if len(pattern_uuid.findall(line)) == 0: print(line) + uuid = pattern_uuid.findall(line)[0] + subject_type = pattern_type.findall(line) + + if len(subject_type) < 1: + if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line: + subject_type = 'MemoryObject' + if 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line: + subject_type = 'NetFlowObject' + if 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line: + subject_type = 'UnnamedPipeObject' + else: + subject_type = subject_type[0] + + if uuid == '00000000-0000-0000-0000-000000000000' or subject_type in ['SUBJECT_UNIT']: + continue + id_nodetype_map[uuid] = subject_type + if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_file_name.findall(line)[0] + elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_process_name.findall(line)[0] + elif subject_type == 'NetFlowObject' and 
len(pattern_netflow_object_name.findall(line)) > 0: + id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0] + for key in metadata[dataset]: + for file in metadata[dataset][key]: + if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'): + continue + f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8') + fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8') + print('processing {} ...'.format(file)) + for line in tqdm(f): + if 'com.bbn.tc.schema.avro.cdm18.Event' in line: + edgeType = pattern_type.findall(line)[0] + timestamp = pattern_time.findall(line)[0] + srcId = pattern_src.findall(line) + + if len(srcId) == 0: continue + srcId = srcId[0] + if not srcId in id_nodetype_map: + continue + srcType = id_nodetype_map[srcId] + dstId1 = pattern_dst1.findall(line) + if len(dstId1) > 0 and dstId1[0] != 'null': + dstId1 = dstId1[0] + if not dstId1 in id_nodetype_map: + continue + dstType1 = id_nodetype_map[dstId1] + this_edge1 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId1) + '\t' + str( + dstType1) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge1) + + dstId2 = pattern_dst2.findall(line) + if len(dstId2) > 0 and dstId2[0] != 'null': + dstId2 = dstId2[0] + if not dstId2 in id_nodetype_map.keys(): + continue + dstType2 = id_nodetype_map[dstId2] + this_edge2 = str(srcId) + '\t' + str(srcType) + '\t' + str(dstId2) + '\t' + str( + dstType2) + '\t' + str(edgeType) + '\t' + str(timestamp) + '\n' + fw.write(this_edge2) + fw.close() + f.close() + if len(id_nodename_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'names.json', 'w', encoding='utf-8') + json.dump(id_nodename_map, fw) + if len(id_nodetype_map) != 0: + fw = open('../data/{}/'.format(dataset) + 'types.json', 'w', encoding='utf-8') + json.dump(id_nodetype_map, fw) + + +def read_graphs(dataset): + malicious_entities = '../data/{}/{}.txt'.format(dataset, dataset) + f = open(malicious_entities, 'r') + malicious_entities = set() + for l in f.readlines(): + malicious_entities.add(l.lstrip().rstrip()) + + preprocess_dataset(dataset) + train_gs = [] + for file in metadata[dataset]['train']: + _, train_g = read_single_graph(dataset, malicious_entities, file, False) + train_gs.append(train_g) + test_gs = [] + test_node_map = {} + count_node = 0 + for file in metadata[dataset]['test']: + node_map, test_g = read_single_graph(dataset, malicious_entities, file, True) + assert len(node_map) == test_g.number_of_nodes() + test_gs.append(test_g) + for key in node_map: + if key not in test_node_map: + test_node_map[key] = node_map[key] + count_node + count_node += test_g.number_of_nodes() + + if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)): + with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodename_map = json.load(f) + with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f: + id_nodetype_map = json.load(f) + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject': + final_malicious_entities.append(test_node_map[e]) + if e in id_nodename_map: + malicious_names.append(id_nodename_map[e]) + f.write('{}\t{}\n'.format(e, id_nodename_map[e])) + else: + malicious_names.append(e) + 
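The intermediate `<file>.txt` edge lists written by `preprocess_dataset` hold one tab-separated edge per line; the sketch below shows an invented example and the direction flip that `read_single_graph` applies to read-like events:

```
# Illustrative line from the intermediate '<file>.txt' edge list; the UUIDs
# and timestamp are invented.
line = "uuid-proc\tSUBJECT_PROCESS\tuuid-file\tFILE_OBJECT_FILE\tEVENT_READ\t1523018600000000000\n"
src, src_type, dst, dst_type, edge_type, ts = line.strip("\n").split("\t")
ts = int(ts)

# read_single_graph() reverses READ/RECV/LOAD events so information always
# flows from source to destination in the resulting graph.
if "READ" in edge_type or "RECV" in edge_type or "LOAD" in edge_type:
    src, dst, src_type, dst_type = dst, src, dst_type, src_type
print(src, "->", dst, edge_type)  # uuid-file -> uuid-proc EVENT_READ
```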
f.write('{}\t{}\n'.format(e, e)) + else: + f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8') + final_malicious_entities = [] + malicious_names = [] + for e in malicious_entities: + if e in test_node_map: + final_malicious_entities.append(test_node_map[e]) + malicious_names.append(e) + f.write('{}\t{}\n'.format(e, e)) + + pkl.dump((final_malicious_entities, malicious_names), open('../data/{}/malicious.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], open('../data/{}/train.pkl'.format(dataset), 'wb')) + pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], open('../data/{}/test.pkl'.format(dataset), 'wb')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CDM Parser') + parser.add_argument("--dataset", type=str, default="trace") + args = parser.parse_args() + if args.dataset not in ['trace', 'theia', 'cadets']: + raise NotImplementedError + read_graphs(args.dataset) + diff --git a/utils/utils.py b/utils/utils.py old mode 100644 new mode 100755 diff --git a/utils/wget_parser.py b/utils/wget_parser.py new file mode 100755 index 0000000000000000000000000000000000000000..3ffbcffda1c5e9176cf1e5219b8c1e0ed1574055 --- /dev/null +++ b/utils/wget_parser.py @@ -0,0 +1,820 @@ +import argparse +import json +import os + +import xxhash +import tqdm +import logging +import networkx as nx +from tqdm import tqdm +import time +import datetime + + +valid_node_type = ['file', 'process_memory', 'task', 'mmaped_file', 'path', 'socket', 'address', 'link'] +CONSOLE_ARGUMENTS = None + + +def hashgen(l): + """Generate a single hash value from a list. @l is a list of + string values, which can be properties of a node/edge. This + function returns a single hashed integer value.""" + hasher = xxhash.xxh64() + for e in l: + hasher.update(e) + return hasher.intdigest() + + +def parse_nodes(json_string, node_map): + """Parse a CamFlow JSON string that may contain nodes ("activity" or "entity"). + Parsed nodes populate @node_map, which is a dictionary that maps the node's UID, + which is assigned by CamFlow to uniquely identify a node object, to a hashed + value (in str) which represents the 'type' of the node. """ + json_object = None + try: + # use "ignore" if non-decodeable exists in the @json_string + json_object = json.loads(json_string) + except Exception as e: + print("Exception ({}) occurred when parsing a node in JSON:".format(e)) + print(json_string) + exit(1) + if "activity" in json_object: + activity = json_object["activity"] + for uid in activity: + if not uid in node_map: # only parse unseen nodes + if "prov:type" not in activity[uid]: + # a node must have a type. + # record this issue if logging is turned on + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic activity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = activity[uid]["prov:type"] + + if "entity" in json_object: + entity = json_object["entity"] + for uid in entity: + if not uid in node_map: + if "prov:type" not in entity[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("skipping a problematic entity node with no 'prov:type': {}".format(uid)) + else: + node_map[uid] = entity[uid]["prov:type"] + + +def parse_all_nodes(filename, node_map): + """Parse all nodes in CamFlow data. @filename is the file path of + the CamFlow data to parse. @node_map contains the mappings of all + CamFlow nodes to their hashed attributes. 
""" + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing nodes in CamFlow data from {}'.format(filename) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(filename, 'r') as f: + # each line in CamFlow data could contain multiple + # provenance nodes, we call @parse_nodes routine. + for line in f: + pb.update() # for progress tracking + parse_nodes(line, node_map) + f.close() + pb.close() + + +def parse_all_edges(inputfile, outputfile, node_map, noencode): + """Parse all edges (including their timestamp) from CamFlow data file @inputfile + to an @outputfile. Before this function is called, parse_all_nodes should be called + to populate the @node_map for all nodes in the CamFlow file. If @noencode is set, + we do not hash the nodes' original UUIDs generated by CamFlow to integers. This + function returns the total number of valid edge parsed from CamFlow dataset. + + The output edgelist has the following format for each line, if -s is not set: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp> + If -s is set, each line would look like: + <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>:<timestamp_stats>""" + total_edges = 0 + smallest_timestamp = None + # scan through the entire file to find the smallest timestamp from all the edges. + # this step is only needed if we need to add some statistical information. + if CONSOLE_ARGUMENTS.stats: + description = '\x1b[6;30;42m[STATUS]\x1b[0m Scanning edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + continue + if "cf:date" not in used[uid]: + continue + if "prov:entity" not in used[uid]: + continue + if "prov:activity" not in used[uid]: + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + continue + if "cf:date" not in wasGeneratedBy[uid]: + continue + if "prov:entity" not in wasGeneratedBy[uid]: + continue + if "prov:activity" not in wasGeneratedBy[uid]: + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + continue + if "cf:date" not in wasInformedBy[uid]: + continue + if "prov:informant" not in wasInformedBy[uid]: + continue + if "prov:informed" not in wasInformedBy[uid]: + 
continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + continue + if "cf:date" not in wasDerivedFrom[uid]: + continue + if "prov:usedEntity" not in wasDerivedFrom[uid]: + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + continue + if "cf:date" not in wasAssociatedWith[uid]: + continue + if "prov:agent" not in wasAssociatedWith[uid]: + continue + if "prov:activity" not in wasAssociatedWith[uid]: + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + continue + if dstUUID not in node_map: + continue + timestamp_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(timestamp_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + if smallest_timestamp == None or ts < smallest_timestamp: + smallest_timestamp = ts + f.close() + pb.close() + + # we will go through the CamFlow data (again) and output edgelist to a file + output = open(outputfile, "w+") + description = '\x1b[6;30;42m[STATUS]\x1b[0m Parsing edges in CamFlow data from {}'.format(inputfile) + pb = tqdm(desc=description, mininterval=1.0, unit=" recs") + with open(inputfile, 'r') as f: + for line in f: + pb.update() + json_object = json.loads(line) + + if "used" in json_object: + used = json_object["used"] + for uid in used: + if "prov:type" not in used[uid]: + # an edge must have a type; if not, + # we will have to skip the edge. Log + # this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without type: {}".format(uid)) + continue + else: + edgetype = "used" + # cf:id is used as logical timestamp to order edges + if "cf:id" not in used[uid]: + # an edge must have a logical timestamp; + # if not we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = used[uid]["cf:id"] + if "prov:entity" not in used[uid]: + # an edge's source node must exist; + # if not, we will have to skip the + # edge. Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without source UUID: {}".format(used[uid]["prov:type"], uid)) + continue + if "prov:activity" not in used[uid]: + # an edge's destination node must exist; + # if not, we will have to skip the edge. 
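+                        # (the same validation is repeated for each relation type parsed below)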
+ # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record without destination UUID: {}".format(used[uid]["prov:type"], + uid)) + continue + srcUUID = used[uid]["prov:entity"] + dstUUID = used[uid]["prov:activity"] + # both source and destination node must + # exist in @node_map; if not, we will + # have to skip the edge. Log this issue + # if verbose is set. + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen srcUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug( + "edge (used/{}) record with an unseen dstUUID: {}".format(used[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in used[uid]: + # an edge must have a timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without timestamp: {}".format(uid)) + continue + else: + # we only record @adjusted_ts if we need + # to record stats of CamFlow dataset. + if CONSOLE_ARGUMENTS.stats: + ts_str = used[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in used[uid]: + # an edge must have a jiffies timestamp; if + # not, we will have to skip the edge. + # Log this issue if verbose is set. + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (used) record without jiffies: {}".format(uid)) + continue + else: + # we only record @jiffies if + # the option is set + if CONSOLE_ARGUMENTS.jiffies: + jiffies = used[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, edgetype, timestamp)) + + if "wasGeneratedBy" in json_object: + wasGeneratedBy = json_object["wasGeneratedBy"] + for uid in wasGeneratedBy: + if "prov:type" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasGeneratedBy" + if "cf:id" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasGeneratedBy[uid]["cf:id"] + if "prov:entity" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without source UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasGeneratedBy[uid]: + if 
CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record without destination UUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasGeneratedBy[uid]["prov:activity"] + dstUUID = wasGeneratedBy[uid]["prov:entity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unseen srcUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy/{}) record with an unsen dstUUID: {}".format( + wasGeneratedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasGeneratedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasGeneratedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasGeneratedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasGeneratedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasInformedBy" in json_object: + wasInformedBy = json_object["wasInformedBy"] + for uid in wasInformedBy: + if "prov:type" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without type: {}".format(uid)) + continue + else: + edgetype = "wasInformedBy" + if "cf:id" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasInformedBy[uid]["cf:id"] + if "prov:informant" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without source UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + if "prov:informed" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record without destination UUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + srcUUID = wasInformedBy[uid]["prov:informant"] + dstUUID = wasInformedBy[uid]["prov:informed"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen srcUUID: {}".format( + 
wasInformedBy[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy/{}) record with an unseen dstUUID: {}".format( + wasInformedBy[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasInformedBy[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasInformedBy[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasInformedBy) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasInformedBy[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasDerivedFrom" in json_object: + wasDerivedFrom = json_object["wasDerivedFrom"] + for uid in wasDerivedFrom: + if "prov:type" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without type: {}".format(uid)) + continue + else: + edgetype = "wasDerivedFrom" + if "cf:id" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasDerivedFrom[uid]["cf:id"] + if "prov:usedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without source UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + if "prov:generatedEntity" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record without destination UUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + srcUUID = wasDerivedFrom[uid]["prov:usedEntity"] + dstUUID = wasDerivedFrom[uid]["prov:generatedEntity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen srcUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom/{}) record with an unseen dstUUID: {}".format( + wasDerivedFrom[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + 
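+                            # wasDerivedFrom records without a cf:date are dropped even when stats are not being collected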
logging.debug("edge (wasDerivedFrom) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasDerivedFrom[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasDerivedFrom[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasDerivedFrom) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasDerivedFrom[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + + if "wasAssociatedWith" in json_object: + wasAssociatedWith = json_object["wasAssociatedWith"] + for uid in wasAssociatedWith: + if "prov:type" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without type: {}".format(uid)) + continue + else: + edgetype = "wasAssociatedWith" + if "cf:id" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without logical timestamp: {}".format(uid)) + continue + else: + timestamp = wasAssociatedWith[uid]["cf:id"] + if "prov:agent" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without source UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + if "prov:activity" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record without destination UUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + srcUUID = wasAssociatedWith[uid]["prov:agent"] + dstUUID = wasAssociatedWith[uid]["prov:activity"] + if srcUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen srcUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + srcVal = node_map[srcUUID] + if dstUUID not in node_map: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith/{}) record with an unseen dstUUID: {}".format( + wasAssociatedWith[uid]["prov:type"], uid)) + continue + else: + dstVal = node_map[dstUUID] + if "cf:date" not in wasAssociatedWith[uid]: + if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without timestamp: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.stats: + ts_str = wasAssociatedWith[uid]["cf:date"] + ts = time.mktime(datetime.datetime.strptime(ts_str, "%Y:%m:%dT%H:%M:%S").timetuple()) + adjusted_ts = ts - smallest_timestamp + if "cf:jiffies" not in wasAssociatedWith[uid]: 
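+                        # cf:jiffies must be present even though it is only written out when the jiffies option is enabled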
+ if CONSOLE_ARGUMENTS.verbose: + logging.debug("edge (wasAssociatedWith) record without jiffies: {}".format(uid)) + continue + else: + if CONSOLE_ARGUMENTS.jiffies: + jiffies = wasAssociatedWith[uid]["cf:jiffies"] + total_edges += 1 + if noencode: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(srcUUID, dstUUID, srcVal, dstVal, edgetype, timestamp)) + else: + if CONSOLE_ARGUMENTS.stats: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + adjusted_ts)) + elif CONSOLE_ARGUMENTS.jiffies: + output.write( + "{}\t{}\t{}:{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, + dstVal, edgetype, timestamp, + jiffies)) + else: + output.write( + "{}\t{}\t{}:{}:{}:{}\n".format(hashgen([srcUUID]), hashgen([dstUUID]), srcVal, dstVal, + edgetype, timestamp)) + f.close() + output.close() + pb.close() + return total_edges + +def read_single_graph(file_name, threshold): + graph = [] + edge_cnt = 0 + with open(file_name, 'r') as f: + for line in f: + try: + edge = line.strip().split("\t") + new_edge = [edge[0], edge[1]] + attributes = edge[2].strip().split(":") + source_node_type = attributes[0] + destination_node_type = attributes[1] + edge_type = attributes[2] + edge_order = attributes[3] + + new_edge.append(source_node_type) + new_edge.append(destination_node_type) + new_edge.append(edge_type) + new_edge.append(edge_order) + graph.append(new_edge) + edge_cnt += 1 + except: + print("{}".format(line)) + f.close() + graph.sort(key=lambda e: e[5]) + if len(graph) <= threshold: + return graph + else: + return graph[:threshold] + + +def process_graph(name, threshold): + graph = read_single_graph(name, threshold) + result_graph = nx.DiGraph() + cnt = 0 + for num, edge in enumerate(graph): + cnt += 1 + src, dst, src_type, dst_type, edge_type = edge[:5] + if True:# src_type in valid_node_type and dst_type in valid_node_type: + if not result_graph.has_node(src): + result_graph.add_node(src, type=src_type) + if not result_graph.has_node(dst): + result_graph.add_node(dst, type=dst_type) + if not result_graph.has_edge(src, dst): + result_graph.add_edge(src, dst, type=edge_type) + if bidirection: + result_graph.add_edge(dst, src, type='reverse_{}'.format(edge_type)) + return cnt, result_graph + + +node_type_list = [] +edge_type_list = [] +node_type_dict = {} +edge_type_dict = {} + + +def format_graph(g, name): + new_g = nx.DiGraph() + node_map = {} + node_cnt = 0 + for n in g.nodes: + node_map[n] = node_cnt + new_g.add_node(node_cnt, type=g.nodes[n]['type']) + node_cnt += 1 + for e in g.edges: + new_g.add_edge(node_map[e[0]], node_map[e[1]], type=g.edges[e]['type']) + for n in new_g.nodes: + node_type = new_g.nodes[n]['type'] + if not node_type in node_type_dict: + node_type_list.append(node_type) + node_type_dict[node_type] = 1 + else: + node_type_dict[node_type] += 1 + for e in new_g.edges: + edge_type = new_g.edges[e]['type'] + if not edge_type in edge_type_dict: + edge_type_list.append(edge_type) + edge_type_dict[edge_type] = 1 + else: + edge_type_dict[edge_type] += 1 + for n in new_g.nodes: + new_g.nodes[n]['type'] = node_type_list.index(new_g.nodes[n]['type']) + for e in new_g.edges: + new_g.edges[e]['type'] 
= edge_type_list.index(new_g.edges[e]['type'])
+    with open('{}.json'.format(name), 'w', encoding='utf-8') as f:
+        json.dump(nx.node_link_data(new_g), f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Convert CamFlow JSON to Unicorn edgelist')
+    args = parser.parse_args()
+    args.stats = False
+    args.verbose = False
+    args.jiffies = False
+    args.input = '../data/wget/raw/'
+    args.output = '../data/wget/processed/'
+    args.final_output = '../data/wget/final/'
+    args.noencode = False
+    if not os.path.exists(args.input):
+        os.mkdir(args.input)
+    if not os.path.exists(args.output):
+        os.mkdir(args.output)
+    if not os.path.exists(args.final_output):
+        os.mkdir(args.final_output)
+    CONSOLE_ARGUMENTS = args
+
+    if args.verbose:
+        logging.basicConfig(filename=args.log, level=logging.DEBUG)
+    cnt = 0
+    for fname in os.listdir(args.input):
+        cnt += 1
+        node_map = dict()
+        parse_all_nodes(args.input + '/{}'.format(fname), node_map)
+        total_edges = parse_all_edges(args.input + '/{}'.format(fname), args.output + '/{}.log'.format(cnt), node_map,
+                                      args.noencode)
+        if args.stats:
+            total_nodes = len(node_map)
+            stats = open(args.stats_file + '/{}.log'.format(cnt), "a+")
+            stats.write("{},{},{}\n".format(args.input + '/{}'.format(fname), total_nodes, total_edges))
+
+    bidirection = False
+    threshold = 10000000  # infinity
+    interaction_dict = []
+    graph_cnt = 0
+    result_graphs = []
+    input = args.output
+    base = args.final_output
+
+    line_cnt = 0
+    for i in tqdm(range(cnt)):
+        single_cnt, result_graph = process_graph('{}{}.log'.format(input, i + 1), threshold)
+        format_graph(result_graph, '{}{}'.format(base, i))
+        line_cnt += single_cnt
+
+    print(line_cnt // 150)
+    print(len(node_type_list))
+    print(node_type_dict)
+    print(len(edge_type_list))
+    print(edge_type_dict)
+
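+# Usage sketch, based on the defaults hard-coded above: place raw CamFlow JSON logs
+# under ../data/wget/raw/ and run this script (from the utils/ directory, given the
+# relative ../data paths). Each input file is first turned into an edgelist
+# ../data/wget/processed/<n>.log whose lines (without stats/jiffies) look like
+#   <src_id>\t<dst_id>\t<src_type>:<dst_type>:<edge_type>:<logical_timestamp>
+# and is then re-indexed into a node-link JSON graph under ../data/wget/final/.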