diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..968181a1cfaa6de944115c62b2b187376c734502
--- /dev/null
+++ b/train.py
@@ -0,0 +1,100 @@
+import os
+import random
+import torch
+import warnings
+from tqdm import tqdm
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata, transform_graph
+from model.autoencoder import build_model
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+import dgl
+from model.train import batch_level_train
+from utils.utils import set_random_seed, create_optimizer
+from utils.config import build_args
+warnings.filterwarnings('ignore')
+
+
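+# Build a shuffled GraphDataLoader over the training graphs.
+# SubsetRandomSampler re-draws a random index order every epoch, so the
+# up-front random.shuffle only randomizes the initial order of `entries`.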
+def extract_dataloaders(entries, batch_size):
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+def main(main_args):
+    device = "cpu"
+    dataset_name =  "trace"
+    if dataset_name == 'streamspot':
+        main_args.num_hidden = 256
+        main_args.max_epoch = 5
+        main_args.num_layers = 4
+    elif dataset_name == 'wget':
+        main_args.num_hidden = 256
+        main_args.max_epoch = 2
+        main_args.num_layers = 4
+    else:
+        main_args["num_hidden"] = 64
+        main_args["max_epoch"] = 50
+        main_args["num_layers"] = 3
+    set_random_seed(0)
+
+    if dataset_name in ('streamspot', 'wget'):
+        batch_size = 12 if dataset_name == 'streamspot' else 1
+        dataset = load_batch_level_dataset(dataset_name)
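+        # The dataset bundle carries the node/edge feature dimensions, the
+        # list of graphs, and the indices of the graphs used for training.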
+        n_node_feat = dataset['n_feat']
+        n_edge_feat = dataset['e_feat']
+        graphs = dataset['dataset']
+        train_index = dataset['train_index']
+        main_args.n_dim = n_node_feat
+        main_args.e_dim = n_edge_feat
+        model = build_model(main_args)
+        model = model.to(device)
+        optimizer = create_optimizer(main_args.optimizer, model, main_args.lr, main_args.weight_decay)
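+        # batch_level_train runs the full training loop over the batched
+        # graphs and returns the trained model.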
+        model = batch_level_train(model, graphs, extract_dataloaders(train_index, batch_size),
+                                  optimizer, main_args.max_epoch, device, main_args.n_dim, main_args.e_dim)
+        os.makedirs("./checkpoints", exist_ok=True)
+        torch.save(model.state_dict(), "./checkpoints/checkpoint-{}.pt".format(dataset_name))
+    else:
+        metadata = load_metadata(dataset_name)
+        main_args["n_dim"] = metadata['node_feature_dim']
+        main_args["e_dim"] = metadata['edge_feature_dim']
+        model = build_model(main_args)
+        model = model.to(device)
+        model.train()
+        optimizer = create_optimizer(main_args["optimizer"], model, main_args["lr"], main_args["weight_decay"])
+        epoch_iter = tqdm(range(main_args["max_epoch"]))
+        n_train = metadata['n_train']
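+        # Each sample is one full graph; the loss is scaled by 1/n_train, which
+        # damps each gradient step and makes epoch_loss the mean per-graph loss.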
+        for epoch in epoch_iter:
+            epoch_loss = 0.0
+            for i in range(n_train):
+                g = load_entity_level_dataset(dataset_name, 'train', i).to(device)
+                optimizer.zero_grad()
+                loss = model(g) / n_train
+                loss.backward()
+                optimizer.step()
+                epoch_loss += loss.item()
+                del g
+            epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
+            torch.save(model.state_dict(), "./result/checkpoint-{}.pt".format(dataset_name))
+
+    return
+
+
+if __name__ == '__main__':
+    args = build_args()
+    main(args)