diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..99c9a5b49e31a631f30053db0c3cf0623569b2eb
--- /dev/null
+++ b/trainer/magic_trainer.py
@@ -0,0 +1,186 @@
+import logging
+import random
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+
+from fedml.core import ClientTrainer
+from data.data_loader import transform_data
+from model.train import batch_level_train
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.utils import create_optimizer
+from utils.poolers import Pooling
+from utils.config import build_args
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
+
+# Federated client trainer for MAGIC graph autoencoders. The evaluation metric is ROC-AUC.
+def extract_dataloaders(entries, batch_size):
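+    """Shuffle the graph entries and wrap them in a randomly sampled GraphDataLoader."""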
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+class MagicTrainer(ClientTrainer):
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
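+        # Best test ROC-AUC observed so far; used to decide when to snapshot parameters.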
+        self.max = 0
+
+    def get_model_params(self):
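+        # Move the model to CPU so the returned state dict holds CPU tensors.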
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def train(self, train_data, device, args):
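+        # Two training regimes: batch-level datasets (wget, streamspot) train over
+        # many small graphs; entity-level datasets train over a few large graphs.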
+        test_data = None  # self.test() loads its own data; this is just a placeholder.
+        # Note: the incoming args are replaced by build_args() defaults, then
+        # overridden per dataset.
+        args = build_args()
+        if self.name == "wget":
+            args["num_hidden"] = 256
+            args["max_epoch"] = 2
+            args["num_layers"] = 4
+            batch_size = 1
+        elif self.name == "streamspot":
+            args["num_hidden"] = 256
+            args["max_epoch"] = 5
+            args["num_layers"] = 4
+            batch_size = 12
+        else:
+            args["num_hidden"] = 64
+            args["max_epoch"] = 50
+            args["num_layers"] = 3
+
+        best_model_params = {}
+        if self.name in ("wget", "streamspot"):
+            dataset = load_batch_level_dataset(self.name)
+            data = transform_data(train_data)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            graphs = data['dataset']
+            train_index = data['train_index']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            self.model = self.model.to(device)
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            self.model = batch_level_train(self.model, graphs, extract_dataloaders(train_index, batch_size),
+                                           optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat)
+        else:
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            self.model = self.model.to(device)
+            self.model.train()
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            epoch_iter = tqdm(range(args["max_epoch"]))
+            n_train = len(train_data[0])
+            for epoch in epoch_iter:
+                epoch_loss = 0.0
+                for i in range(n_train):
+                    g = train_data[0][i]
+                    self.model.train()
+                    loss = self.model(g)
+                    # Scale each graph's loss so the summed epoch loss is the per-graph mean.
+                    loss /= n_train
+                    optimizer.zero_grad()
+                    epoch_loss += loss.item()
+                    loss.backward()
+                    optimizer.step()
+                    del g
+                epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
+        if self.name in ("wget", "streamspot"):
+            test_score, _ = self.test(test_data, device, args)
+            if test_score > self.max:
+                self.max = test_score
+                best_model_params = {
+                    k: v.cpu() for k, v in self.model.state_dict().items()
+                }
+        else:
+            # Entity-level training does no in-round testing; return the current parameters.
+            self.max = 0
+            best_model_params = {
+                k: v.cpu() for k, v in self.model.state_dict().items()
+            }
+
+        return self.max, best_model_params
+
+    def test(self, test_data, device, args):
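+        # Batch-level datasets: graph-level k-NN detection on pooled embeddings.
+        # Entity-level datasets: node-level k-NN detection against training-graph embeddings.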
+        if self.name in ("wget", "streamspot"):
+            logging.info("----------test--------")
+            args = build_args()
+            model = self.model
+            model.eval()
+            model.to(device)
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'],
+                                                        self.name, args["n_dim"], args["e_dim"])
+        else:
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                    g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                    x_train.append(model.embed(g).cpu().numpy())
+                    del g
+                x_train = np.concatenate(x_train, axis=0)
+                skip_benign = 0
+                x_test = []
+                for i in range(n_test):
+                    g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                    # All test graphs except the last hold nodes that also appear
+                    # in training; count them so they can be excluded below.
+                    if i != n_test - 1:
+                        skip_benign += g.number_of_nodes()
+                    x_test.append(model.embed(g).cpu().numpy())
+                    del g
+                x_test = np.concatenate(x_test, axis=0)
+
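+            # Node-level labels: 1 for known-malicious node ids, 0 otherwise.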
+            n = x_test.shape[0]
+            y_test = np.zeros(n)
+            y_test[malicious] = 1.0
+
+            # Exclude training samples from the test set
+            test_idx = []
+            for i in range(x_test.shape[0]):
+                if i >= skip_benign or y_test[i] == 1.0:
+                    test_idx.append(i)
+            result_x_test = x_test[test_idx]
+            result_y_test = y_test[test_idx]
+            del x_test, y_test
+            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                       result_y_test)
+
+        return test_auc, model
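+
+# Minimal usage sketch (hypothetical wiring; in practice the FedML runner
+# constructs the trainer and supplies train_data and the device):
+#
+#   from model.autoencoder import build_model
+#   args = build_args()
+#   model = build_model(args)
+#   trainer = MagicTrainer(model, args, name="streamspot")
+#   best_auc, best_params = trainer.train(train_data, torch.device("cuda"), args)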
+