diff --git a/trainer/utils/trace_parser.py b/trainer/utils/trace_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a91d424a3eea0bbec87ffb635a0d98b27d72f7
--- /dev/null
+++ b/trainer/utils/trace_parser.py
@@ -0,0 +1,288 @@
+import argparse
+import json
+import os
+import pickle as pkl
+import random
+import re
+
+import networkx as nx
+from tqdm import tqdm
+
+
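+# Global string -> integer id maps for node and edge types, shared across all input files.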
+node_type_dict = {}
+edge_type_dict = {}
+node_type_cnt = 0
+edge_type_cnt = 0
+
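+# Train/test splits over the DARPA TC Engagement 3 log files for each dataset.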
+metadata = {
+    'trace': {
+        'train': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3'],
+        'test': ['ta1-trace-e3-official-1.json.0', 'ta1-trace-e3-official-1.json.1', 'ta1-trace-e3-official-1.json.2', 'ta1-trace-e3-official-1.json.3', 'ta1-trace-e3-official-1.json.4']
+    },
+    'theia': {
+        'train': ['ta1-theia-e3-official-6r.json', 'ta1-theia-e3-official-6r.json.1', 'ta1-theia-e3-official-6r.json.2', 'ta1-theia-e3-official-6r.json.3'],
+        'test': ['ta1-theia-e3-official-6r.json.8']
+    },
+    'cadets': {
+        'train': ['ta1-cadets-e3-official.json', 'ta1-cadets-e3-official.json.1', 'ta1-cadets-e3-official.json.2', 'ta1-cadets-e3-official-2.json.1'],
+        'test': ['ta1-cadets-e3-official-2.json']
+    }
+}
+
+
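+# Field-extraction regexes: records are matched line-by-line, which avoids
+# full JSON parsing of the very large CDM18 logs.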
+pattern_uuid = re.compile(r'uuid\":\"(.*?)\"')
+pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
+pattern_type = re.compile(r'type\":\"(.*?)\"')
+pattern_time = re.compile(r'timestampNanos\":(.*?),')
+pattern_file_name = re.compile(r'map\":\{\"path\":\"(.*?)\"')
+pattern_process_name = re.compile(r'map\":\{\"name\":\"(.*?)\"')
+pattern_netflow_object_name = re.compile(r'remoteAddress\":\"(.*?)\"')
+
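+# Adversarial-evaluation mode ('MFE', 'MCE', 'MSE', or 'BFP'); 'None' disables
+# all perturbations.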
+adversarial = 'None'
+
+
+def read_single_graph(dataset, malicious, path, test=False):
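+    """Parse one preprocessed edge file into a networkx.DiGraph.
+
+    Node and edge type strings are mapped to integer ids via the global
+    dictionaries. When test is False, edges touching known-malicious
+    entities (MemoryObjects excepted) are dropped so the training graphs
+    stay benign. Returns (uuid -> node id map, graph).
+    """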
+    global node_type_cnt, edge_type_cnt
+    g = nx.DiGraph()
+    print('converting {} ...'.format(path))
+    path = '../data/{}/'.format(dataset) + path + '.txt'
+    f = open(path, 'r')
+    lines = []
+    for l in f:
+        src, src_type, dst, dst_type, edge_type, ts = l.strip().split('\t')
+        ts = int(ts)
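+        # In training mode, drop edges that touch known-malicious entities
+        # (MemoryObjects excepted) so the model only sees benign activity.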
+        if not test:
+            if src in malicious or dst in malicious:
+                if src in malicious and src_type != 'MemoryObject':
+                    continue
+                if dst in malicious and dst_type != 'MemoryObject':
+                    continue
+
+        if src_type not in node_type_dict:
+            node_type_dict[src_type] = node_type_cnt
+            node_type_cnt += 1
+        if dst_type not in node_type_dict:
+            node_type_dict[dst_type] = node_type_cnt
+            node_type_cnt += 1
+        if edge_type not in edge_type_dict:
+            edge_type_dict[edge_type] = edge_type_cnt
+            edge_type_cnt += 1
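+        # READ/RECV/LOAD events move data from the object into the subject, so
+        # flip them to point along the direction of information flow.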
+        if 'READ' in edge_type or 'RECV' in edge_type or 'LOAD' in edge_type:
+            lines.append([dst, src, dst_type, src_type, edge_type, ts])
+        else:
+            lines.append([src, dst, src_type, dst_type, edge_type, ts])
+    f.close()
+    # Process events in chronological order.
+    lines.sort(key=lambda l: l[5])
+
+    node_map = {}
+    node_type_map = {}
+    node_cnt = 0
+    node_list = []
+    for l in lines:
+        src, dst, src_type, dst_type, edge_type = l[:5]
+        src_type_id = node_type_dict[src_type]
+        dst_type_id = node_type_dict[dst_type]
+        edge_type_id = edge_type_dict[edge_type]
+        if src not in node_map:
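+            # MFE/MCE disguise malicious test nodes as ordinary files; BFP
+            # randomly relabels 1-in-20 benign training nodes as NetFlowObject.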
+            if test and adversarial in ('MFE', 'MCE') and src in malicious:
+                src_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test and adversarial == 'BFP' and src not in malicious and random.randint(1, 20) == 1:
+                src_type_id = node_type_dict['NetFlowObject']
+            node_map[src] = node_cnt
+            g.add_node(node_cnt, type=src_type_id)
+            node_list.append(src)
+            node_type_map[src] = src_type
+            node_cnt += 1
+        if dst not in node_map:
+            if test and adversarial in ('MFE', 'MCE') and dst in malicious:
+                dst_type_id = node_type_dict['FILE_OBJECT_FILE']
+            if not test and adversarial == 'BFP' and dst not in malicious and random.randint(1, 20) == 1:
+                dst_type_id = node_type_dict['NetFlowObject']
+            node_map[dst] = node_cnt
+            g.add_node(node_cnt, type=dst_type_id)
+            node_type_map[dst] = dst_type
+            node_list.append(dst)
+            node_cnt += 1
+        if not g.has_edge(node_map[src], node_map[dst]):
+            g.add_edge(node_map[src], node_map[dst], type=edge_type_id)
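+    # MSE/MCE: attach one plausible-looking edge to every malicious test node
+    # to perturb its neighborhood structure.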
+    if adversarial in ('MSE', 'MCE') and test:
+        for i, node in enumerate(node_list):
+            if node in malicious:
+                while True:
+                    another_node = random.choice(node_list)
+                    if 'FILE_OBJECT' in node_type_map[node] and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_WRITE'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'SUBJECT_PROCESS' and node_type_map[another_node] == 'FILE_OBJECT_FILE':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_READ'])
+                        print(node, another_node)
+                        break
+                    if node_type_map[node] == 'NetFlowObject' and node_type_map[another_node] == 'SUBJECT_PROCESS':
+                        g.add_edge(node_map[another_node], node_map[node], type=edge_type_dict['EVENT_CONNECT'])
+                        print(node, another_node)
+                        break
+                    if 'FILE_OBJECT' not in node_type_map[node] and node_type_map[node] not in ('SUBJECT_PROCESS', 'NetFlowObject'):
+                        break
+    return node_map, g
+
+
+def preprocess_dataset(dataset):
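+    """Convert the raw CDM18 JSON logs into tab-separated edge files.
+
+    The first pass collects every entity's type (and a human-readable name
+    where one exists); the second pass rewrites each Event record as a
+    src, src_type, dst, dst_type, edge_type, timestamp line.
+    """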
+    id_nodetype_map = {}
+    id_nodename_map = {}
+    for file in os.listdir('../data/{}/'.format(dataset)):
+        if 'json' in file and '.txt' not in file and 'names' not in file and 'types' not in file and 'metadata' not in file:
+            print('reading {} ...'.format(file))
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line or 'com.bbn.tc.schema.avro.cdm18.Host' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.TimeMarker' in line or 'com.bbn.tc.schema.avro.cdm18.StartMarker' in line: continue
+                if 'com.bbn.tc.schema.avro.cdm18.UnitDependency' in line or 'com.bbn.tc.schema.avro.cdm18.EndMarker' in line: continue
+                uuids = pattern_uuid.findall(line)
+                if len(uuids) == 0:
+                    # Surface unparsable records instead of crashing on them.
+                    print(line)
+                    continue
+                uuid = uuids[0]
+                subject_type = pattern_type.findall(line)
+
+                if len(subject_type) < 1:
+                    # Some object records have no top-level "type" field; infer
+                    # the type from the schema name instead.
+                    if 'com.bbn.tc.schema.avro.cdm18.MemoryObject' in line:
+                        subject_type = 'MemoryObject'
+                    elif 'com.bbn.tc.schema.avro.cdm18.NetFlowObject' in line:
+                        subject_type = 'NetFlowObject'
+                    elif 'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject' in line:
+                        subject_type = 'UnnamedPipeObject'
+                    else:
+                        # Unknown record type: skip rather than storing an empty type.
+                        continue
+                else:
+                    subject_type = subject_type[0]
+
+                if uuid == '00000000-0000-0000-0000-000000000000' or subject_type == 'SUBJECT_UNIT':
+                    continue
+                id_nodetype_map[uuid] = subject_type
+                if 'FILE' in subject_type and len(pattern_file_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_file_name.findall(line)[0]
+                elif subject_type == 'SUBJECT_PROCESS' and len(pattern_process_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_process_name.findall(line)[0]
+                elif subject_type == 'NetFlowObject' and len(pattern_netflow_object_name.findall(line)) > 0:
+                    id_nodename_map[uuid] = pattern_netflow_object_name.findall(line)[0]
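+    # Second pass: rewrite each Event record as a tab-separated edge line.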
+    for key in metadata[dataset]:
+        for file in metadata[dataset][key]:
+            if os.path.exists('../data/{}/'.format(dataset) + file + '.txt'):
+                continue
+            f = open('../data/{}/'.format(dataset) + file, 'r', encoding='utf-8')
+            fw = open('../data/{}/'.format(dataset) + file + '.txt', 'w', encoding='utf-8')
+            print('processing {} ...'.format(file))
+            for line in tqdm(f):
+                if 'com.bbn.tc.schema.avro.cdm18.Event' in line:
+                    edgeType = pattern_type.findall(line)[0]
+                    timestamp = pattern_time.findall(line)[0]
+                    srcId = pattern_src.findall(line)
+
+                    if len(srcId) == 0: continue
+                    srcId = srcId[0]
+                    if srcId not in id_nodetype_map:
+                        continue
+                    srcType = id_nodetype_map[srcId]
+                    dstId1 = pattern_dst1.findall(line)
+                    if len(dstId1) > 0 and dstId1[0] != 'null':
+                        dstId1 = dstId1[0]
+                        if dstId1 not in id_nodetype_map:
+                            continue
+                        dstType1 = id_nodetype_map[dstId1]
+                        fw.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(srcId, srcType, dstId1, dstType1, edgeType, timestamp))
+
+                    dstId2 = pattern_dst2.findall(line)
+                    if len(dstId2) > 0 and dstId2[0] != 'null':
+                        dstId2 = dstId2[0]
+                        if dstId2 not in id_nodetype_map:
+                            continue
+                        dstType2 = id_nodetype_map[dstId2]
+                        fw.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(srcId, srcType, dstId2, dstType2, edgeType, timestamp))
+            fw.close()
+            f.close()
+    if len(id_nodename_map) != 0:
+        with open('../data/{}/names.json'.format(dataset), 'w', encoding='utf-8') as fw:
+            json.dump(id_nodename_map, fw)
+    if len(id_nodetype_map) != 0:
+        with open('../data/{}/types.json'.format(dataset), 'w', encoding='utf-8') as fw:
+            json.dump(id_nodetype_map, fw)
+
+
+def read_graphs(dataset):
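+    """Build the train/test graphs for a dataset and pickle them.
+
+    Also resolves the labelled malicious entities to test-graph node ids and
+    writes their human-readable names to malicious_names.txt.
+    """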
+    malicious_path = '../data/{}/{}.txt'.format(dataset, dataset)
+    malicious_entities = set()
+    with open(malicious_path, 'r') as f:
+        for l in f:
+            malicious_entities.add(l.strip())
+
+    preprocess_dataset(dataset)
+    train_gs = []
+    for file in metadata[dataset]['train']:
+        _, train_g = read_single_graph(dataset, malicious_entities, file, False)
+        train_gs.append(train_g)
+    test_gs = []
+    test_node_map = {}
+    count_node = 0
+    for file in metadata[dataset]['test']:
+        node_map, test_g = read_single_graph(dataset, malicious_entities, file, True)
+        assert len(node_map) == test_g.number_of_nodes()
+        test_gs.append(test_g)
+        for key in node_map:
+            if key not in test_node_map:
+                test_node_map[key] = node_map[key] + count_node
+        count_node += test_g.number_of_nodes()
+
+    if os.path.exists('../data/{}/names.json'.format(dataset)) and os.path.exists('../data/{}/types.json'.format(dataset)):
+        with open('../data/{}/names.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodename_map = json.load(f)
+        with open('../data/{}/types.json'.format(dataset), 'r', encoding='utf-8') as f:
+            id_nodetype_map = json.load(f)
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map and e in id_nodetype_map and id_nodetype_map[e] != 'MemoryObject' and id_nodetype_map[e] != 'UnnamedPipeObject':
+                final_malicious_entities.append(test_node_map[e])
+                if e in id_nodename_map:
+                    malicious_names.append(id_nodename_map[e])
+                    f.write('{}\t{}\n'.format(e, id_nodename_map[e]))
+                else:
+                    malicious_names.append(e)
+                    f.write('{}\t{}\n'.format(e, e))
+    else:
+        f = open('../data/{}/malicious_names.txt'.format(dataset), 'w', encoding='utf-8')
+        final_malicious_entities = []
+        malicious_names = []
+        for e in malicious_entities:
+            if e in test_node_map:
+                final_malicious_entities.append(test_node_map[e])
+                malicious_names.append(e)
+                f.write('{}\t{}\n'.format(e, e))
+
+    f.close()
+
+    with open('../data/{}/malicious.pkl'.format(dataset), 'wb') as fw:
+        pkl.dump((final_malicious_entities, malicious_names), fw)
+    with open('../data/{}/train.pkl'.format(dataset), 'wb') as fw:
+        pkl.dump([nx.node_link_data(train_g) for train_g in train_gs], fw)
+    with open('../data/{}/test.pkl'.format(dataset), 'wb') as fw:
+        pkl.dump([nx.node_link_data(test_g) for test_g in test_gs], fw)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='CDM Parser')
+    parser.add_argument("--dataset", type=str, default="trace")
+    args = parser.parse_args()
+    if args.dataset not in ['trace', 'theia', 'cadets']:
+        raise NotImplementedError('unsupported dataset: {}'.format(args.dataset))
+    read_graphs(args.dataset)