diff --git a/trainer/magic_trainer.py b/trainer/magic_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..99c9a5b49e31a631f30053db0c3cf0623569b2eb
--- /dev/null
+++ b/trainer/magic_trainer.py
@@ -0,0 +1,186 @@
+import logging
+import random
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch.utils.data.sampler import SubsetRandomSampler
+from dgl.dataloading import GraphDataLoader
+
+from fedml.core import ClientTrainer
+from model.autoencoder import build_model
+from model.train import batch_level_train
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from data.data_loader import transform_data
+from utils.utils import create_optimizer
+from utils.poolers import Pooling
+from utils.config import build_args
+from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
+
+
+# Federated trainer for the MAGIC graph autoencoder.
+# The evaluation metric is ROC-AUC, computed per graph for the batch-level
+# datasets (wget, streamspot) and per node for the entity-level datasets.
+def extract_dataloaders(entries, batch_size):
+    random.shuffle(entries)
+    train_idx = torch.arange(len(entries))
+    train_sampler = SubsetRandomSampler(train_idx)
+    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
+    return train_loader
+
+
+class MagicTrainer(ClientTrainer):
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
+        self.max = 0
+
+    def get_model_params(self):
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def train(self, train_data, device, args):
+        test_data = None
+        # MAGIC's own hyper-parameters override the federated args.
+        args = build_args()
+        if self.name == "wget":
+            args["num_hidden"] = 256
+            args["max_epoch"] = 2
+            args["num_layers"] = 4
+            batch_size = 1
+        elif self.name == "streamspot":
+            args["num_hidden"] = 256
+            args["max_epoch"] = 5
+            args["num_layers"] = 4
+            batch_size = 12
+        else:
+            args["num_hidden"] = 64
+            args["max_epoch"] = 50
+            args["num_layers"] = 3
+
+        best_model_params = {}
+        if self.name == 'wget' or self.name == 'streamspot':
+            dataset = load_batch_level_dataset(self.name)
+            data = transform_data(train_data)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            graphs = data['dataset']
+            train_index = data['train_index']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            # self.model = build_model(args)
+            self.model = self.model.to(device)
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            self.model = batch_level_train(self.model, graphs, extract_dataloaders(train_index, batch_size),
+                                           optimizer, args["max_epoch"], device, n_node_feat, n_edge_feat)
+        else:
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            self.model = self.model.to(device)
+            self.model.train()
+            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
+            epoch_iter = tqdm(range(args["max_epoch"]))
+            n_train = len(train_data[0])
+            for epoch in epoch_iter:
+                epoch_loss = 0.0
+                for i in range(n_train):
+                    g = train_data[0][i]
+                    self.model.train()
+                    loss = self.model(g)
+                    loss /= n_train
+                    optimizer.zero_grad()
+                    epoch_loss += loss.item()
+                    loss.backward(retain_graph=True)
+                    optimizer.step()
+                    del g
+                epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
+
+        if self.name == 'wget' or self.name == 'streamspot':
+            test_score, _ = self.test(test_data, device, args)
+            if test_score > self.max:
+                self.max = test_score
+                best_model_params = {
+                    k: v.cpu() for k, v in self.model.state_dict().items()
+                }
+        else:
+            # Entity-level datasets have no held-out score during training,
+            # so the current parameters are always returned.
+            self.max = 0
+            best_model_params = {
+                k: v.cpu() for k, v in self.model.state_dict().items()
+            }
+        return self.max, best_model_params
+
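+    # Evaluation strategy:
+    #   * batch-level (wget, streamspot): pooled graph embeddings are scored by
+    #     `batch_level_evaluation` with a k-NN detector, returning ROC-AUC.
+    #   * entity-level: node embeddings of the benign training graphs form the
+    #     reference set; test nodes are scored against it with k-NN.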
+    def test(self, test_data, device, args):
+        if self.name == 'wget' or self.name == 'streamspot':
+            logging.info("----------test--------")
+            args = build_args()
+            model = self.model
+            model.eval()
+            model.to(device)
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name,
+                                                        args["n_dim"], args["e_dim"])
+        else:
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                    g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                    x_train.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_train = np.concatenate(x_train, axis=0)
+
+                skip_benign = 0
+                x_test = []
+                for i in range(n_test):
+                    g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                    # Count the benign nodes so that training samples can be
+                    # excluded from the test set below.
+                    if i != n_test - 1:
+                        skip_benign += g.number_of_nodes()
+                    x_test.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_test = np.concatenate(x_test, axis=0)
+
+                n = x_test.shape[0]
+                y_test = np.zeros(n)
+                y_test[malicious] = 1.0
+                malicious_dict = {m: i for i, m in enumerate(malicious)}
+
+                # Exclude training samples from the test set
+                test_idx = []
+                for i in range(x_test.shape[0]):
+                    if i >= skip_benign or y_test[i] == 1.0:
+                        test_idx.append(i)
+                result_x_test = x_test[test_idx]
+                result_y_test = y_test[test_idx]
+                del x_test, y_test
+                test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                           result_y_test)
+
+        return test_auc, model
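+
+
+# Minimal usage sketch (illustrative assumption, not part of the FedML entry
+# point; the layout of `train_data` and the device choice depend on the
+# surrounding federated pipeline):
+#
+#     args = build_args()
+#     model = build_model(args)
+#     trainer = MagicTrainer(model, args, "streamspot")
+#     best_auc, best_params = trainer.train(train_data, torch.device("cuda"), args)
+#     trainer.set_model_params(best_params)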