diff --git a/trainer/utils/loaddata.py b/trainer/utils/loaddata.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e7dfc03ee39adcc1f5729d59aa21124d981fff
--- /dev/null
+++ b/trainer/utils/loaddata.py
@@ -0,0 +1,197 @@
+import pickle as pkl
+import torch.nn.functional as F
+import dgl
+import networkx as nx
+import json
+from tqdm import tqdm
+import os
+
+
+class StreamspotDataset(dgl.data.DGLDataset):
+    # dgl.data.DGLDataset.__init__ invokes process(), which must be defined
+    # in subclasses; it is a deliberate no-op here because all loading is
+    # done in __init__ below.
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(StreamspotDataset, self).__init__(name=name)
+        if name == 'streamspot':
+            path = './data/streamspot'
+            num_graphs = 600
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs)):
+                # Graph JSON files on disk are 1-indexed.
+                with open('{}/{}.json'.format(path, i + 1)) as f:
+                    g = dgl.from_networkx(
+                        nx.node_link_graph(json.load(f)),
+                        node_attrs=['type'],
+                        edge_attrs=['type']
+                    )
+                self.graphs.append(g)
+                # Graphs 300-399 form the attack scenario; label them anomalous.
+                if 300 <= i <= 399:
+                    self.labels.append(1)
+                else:
+                    self.labels.append(0)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
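+
+
+# A minimal usage sketch (illustrative only; assumes the StreamSpot graphs
+# have been converted to node-link JSON files ./data/streamspot/1.json
+# through 600.json):
+#
+#   dataset = StreamspotDataset('streamspot')
+#   g, label = dataset[0]   # DGLGraph with integer 'type' attrs, 0/1 label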
+
+
+class WgetDataset(dgl.data.DGLDataset):
+    # As in StreamspotDataset, process() is a required no-op; loading
+    # happens in __init__.
+    def process(self):
+        pass
+
+    def __init__(self, name):
+        super(WgetDataset, self).__init__(name=name)
+        if name == 'wget':
+            path = './data/wget/final'
+            num_graphs = 150
+            self.graphs = []
+            self.labels = []
+            print('Loading {} dataset...'.format(name))
+            for i in tqdm(range(num_graphs)):
+                # Wget graph JSON files on disk are 0-indexed.
+                with open('{}/{}.json'.format(path, i)) as f:
+                    g = dgl.from_networkx(
+                        nx.node_link_graph(json.load(f)),
+                        node_attrs=['type'],
+                        edge_attrs=['type']
+                    )
+                self.graphs.append(g)
+                # The first 25 graphs (0-24) are attack runs; label them anomalous.
+                if 0 <= i <= 24:
+                    self.labels.append(1)
+                else:
+                    self.labels.append(0)
+        else:
+            raise NotImplementedError
+
+    def __getitem__(self, i):
+        return self.graphs[i], self.labels[i]
+
+    def __len__(self):
+        return len(self.graphs)
+
+
+def load_rawdata(name):
+    if name == 'streamspot':
+        path = './data/streamspot'
+        dataset_class = StreamspotDataset
+    elif name == 'wget':
+        path = './data/wget'
+        dataset_class = WgetDataset
+    else:
+        raise NotImplementedError
+    # Build the dataset once, then cache it as a pickle for later runs.
+    if os.path.exists(path + '/graphs.pkl'):
+        print('Loading processed {} dataset...'.format(name))
+        with open(path + '/graphs.pkl', 'rb') as f:
+            raw_data = pkl.load(f)
+    else:
+        raw_data = dataset_class(name)
+        with open(path + '/graphs.pkl', 'wb') as f:
+            pkl.dump(raw_data, f)
+    return raw_data
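+
+
+# A hedged usage sketch: the first call parses the raw JSON graphs and
+# writes graphs.pkl; subsequent calls load straight from that cache.
+#
+#   raw_data = load_rawdata('streamspot')   # slow first run, fast afterwards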
+
+
+def load_batch_level_dataset(dataset_name):
+    dataset = load_rawdata(dataset_name)
+    # 'type' ids are 0-based, so the one-hot feature dimensionality is the
+    # largest id plus one.
+    node_feature_dim = 0
+    edge_feature_dim = 0
+    for g, _ in dataset:
+        node_feature_dim = max(node_feature_dim, g.ndata["type"].max().item())
+        edge_feature_dim = max(edge_feature_dim, g.edata["type"].max().item())
+    node_feature_dim += 1
+    edge_feature_dim += 1
+    full_dataset = list(range(len(dataset)))
+    # Train only on benign graphs (label 0); evaluation uses the full set.
+    train_dataset = [i for i in range(len(dataset)) if dataset[i][1] == 0]
+    print('[n_graph, n_node_feat, n_edge_feat]: [{}, {}, {}]'.format(len(dataset), node_feature_dim, edge_feature_dim))
+
+    return {'dataset': dataset,
+            'train_index': train_dataset,
+            'full_index': full_dataset,
+            'n_feat': node_feature_dim,
+            'e_feat': edge_feature_dim}
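+
+
+# Sketch of consuming the returned dict (variable names are illustrative,
+# not part of this module):
+#
+#   data = load_batch_level_dataset('streamspot')
+#   train_graphs = [data['dataset'][i][0] for i in data['train_index']]
+#   n_feat, e_feat = data['n_feat'], data['e_feat']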
+
+
+def transform_graph(g, node_feature_dim, edge_feature_dim):
+    # Attach one-hot encodings of the integer 'type' ids as dense 'attr'
+    # features; the input graph is left unmodified.
+    new_g = g.clone()
+    new_g.ndata["attr"] = F.one_hot(g.ndata["type"].view(-1), num_classes=node_feature_dim).float()
+    new_g.edata["attr"] = F.one_hot(g.edata["type"].view(-1), num_classes=edge_feature_dim).float()
+    return new_g
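+
+
+# For instance, a node of type 2 with node_feature_dim = 4 receives the
+# 'attr' vector [0., 0., 1., 0.]; the resulting tensors have shapes
+# (num_nodes, node_feature_dim) and (num_edges, edge_feature_dim).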
+
+
+def preload_entity_level_dataset(path):
+    path = './data/' + path
+    # metadata.json marks a completed run, so preprocessing happens once.
+    if not os.path.exists(path + '/metadata.json'):
+        print('transforming training graphs')
+        with open(path + '/train.pkl', 'rb') as f:
+            train_gs = [dgl.from_networkx(
+                nx.node_link_graph(g),
+                node_attrs=['type'],
+                edge_attrs=['type']
+            ) for g in pkl.load(f)]
+        print('transforming test graphs')
+        with open(path + '/test.pkl', 'rb') as f:
+            test_gs = [dgl.from_networkx(
+                nx.node_link_graph(g),
+                node_attrs=['type'],
+                edge_attrs=['type']
+            ) for g in pkl.load(f)]
+        with open(path + '/malicious.pkl', 'rb') as f:
+            malicious = pkl.load(f)
+
+        # 'type' ids are 0-based; dimensionality is max id + 1, taken over
+        # both splits so train and test share one encoding.
+        node_feature_dim = 0
+        edge_feature_dim = 0
+        for g in train_gs + test_gs:
+            node_feature_dim = max(g.ndata["type"].max().item(), node_feature_dim)
+            edge_feature_dim = max(g.edata["type"].max().item(), edge_feature_dim)
+        node_feature_dim += 1
+        edge_feature_dim += 1
+        result_train_gs = [transform_graph(g, node_feature_dim, edge_feature_dim) for g in train_gs]
+        result_test_gs = [transform_graph(g, node_feature_dim, edge_feature_dim) for g in test_gs]
+        metadata = {
+            'node_feature_dim': node_feature_dim,
+            'edge_feature_dim': edge_feature_dim,
+            'malicious': malicious,
+            'n_train': len(result_train_gs),
+            'n_test': len(result_test_gs)
+        }
+        with open(path + '/metadata.json', 'w', encoding='utf-8') as f:
+            json.dump(metadata, f)
+        # Persist each transformed graph in its own pickle so entity-level
+        # graphs can be loaded one at a time during training and evaluation.
+        for i, g in enumerate(result_train_gs):
+            with open(path + '/train{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
+        for i, g in enumerate(result_test_gs):
+            with open(path + '/test{}.pkl'.format(i), 'wb') as f:
+                pkl.dump(g, f)
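+
+
+# After preprocessing, ./data/<path>/ holds metadata.json plus one pickle
+# per graph (train0.pkl, train1.pkl, ..., test0.pkl, ...), which the
+# loaders below read back individually.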
+
+
+def load_metadata(path):
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/metadata.json', 'r', encoding='utf-8') as f:
+        metadata = json.load(f)
+    return metadata
+
+
+def load_entity_level_dataset(path, t, n):
+    # t is the split name ('train' or 'test'); n is the graph index within
+    # that split.
+    preload_entity_level_dataset(path)
+    with open('./data/' + path + '/{}{}.pkl'.format(t, n), 'rb') as f:
+        data = pkl.load(f)
+    return data
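+
+
+# Illustrative end-to-end sketch ('theia' is a placeholder dataset
+# directory; it assumes ./data/theia/train.pkl, test.pkl and malicious.pkl
+# exist):
+#
+#   metadata = load_metadata('theia')
+#   for n in range(metadata['n_train']):
+#       g = load_entity_level_dataset('theia', 'train', n)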