From fbcf8812a663d7e0a1e3c23c01e4bb0a8cca3fd0 Mon Sep 17 00:00:00 2001
From: Athmane Mansour Bahar <ja_mansourbahar@esi.dz>
Date: Thu, 15 Aug 2024 17:48:41 +0000
Subject: [PATCH] Upload New File

---
 trainer/magic_aggregator.py | 127 ++++++++++++++++++++++++++++++++++++
 1 file changed, 127 insertions(+)
 create mode 100644 trainer/magic_aggregator.py

diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py
new file mode 100644
index 0000000..a9d7f70
--- /dev/null
+++ b/trainer/magic_aggregator.py
@@ -0,0 +1,127 @@
+import logging
+
+import numpy as np
+import torch
+import wandb
+from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
+from utils.config import build_args
+from fedml.core import ServerAggregator
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.poolers import Pooling
+# Server-side aggregator for MAGIC graph models. The evaluation metric is ROC-AUC.
+from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset
+from utils.loaddata import load_batch_level_dataset
+
+class MagicWgetAggregator(ServerAggregator):
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
+
+    def get_model_params(self):
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def test(self, test_data, device, args):
+        pass
+
+    def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool:
+        logging.info("----------test_on_the_server--------")
+
+        model_list, score_list = [], []
+        for client_idx in test_data_local_dict.keys():
+            test_data = test_data_local_dict[client_idx]
+            score, model = self._test(test_data, device, args)
+            for idx in range(len(model_list)):
+                self._compare_models(model, model_list[idx])
+            model_list.append(model)
+            score_list.append(score)
+            logging.info("Client {}, Test ROC-AUC score = {}".format(client_idx, score))
+            if args.enable_wandb:
+                wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
+        avg_score = np.mean(np.array(score_list))
+        logging.info("Test ROC-AUC Score = {}".format(avg_score))
+        if args.enable_wandb:
+            wandb.log({"Test/ROC-AUC": avg_score})
+        return True
+
+    def _compare_models(self, model_1, model_2):
+        models_differ = 0
+        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
+            if torch.equal(key_item_1[1], key_item_2[1]):
+                pass
+            else:
+                models_differ += 1
+                if key_item_1[0] == key_item_2[0]:
+                    logging.info("Mismatch found at %s", key_item_1[0])
+                else:
+                    raise Exception("state_dict keys do not match")
+        if models_differ == 0:
+            logging.info("Models match perfectly! :)")
+
+    def _test(self, test_data, device, args):
+        args = build_args()
+        if self.name == 'wget' or self.name == 'streamspot':
+            logging.info("----------test--------")
+
+            model = self.model
+            model.eval()
+            model.to(device)
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name, args["n_dim"], args["e_dim"])
+        else:
+            torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name))
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                    g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                    x_train.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_train = np.concatenate(x_train, axis=0)
+                skip_benign = 0
+                x_test = []
+                for i in range(n_test):
+                    g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                    # Count benign nodes in every test graph except the last, so they can be skipped below
+                    if i != n_test - 1:
+                        skip_benign += g.number_of_nodes()
+                    x_test.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_test = np.concatenate(x_test, axis=0)
+
+                n = x_test.shape[0]
+                y_test = np.zeros(n)
+                y_test[malicious] = 1.0
+                malicious_dict = {}
+                for i, m in enumerate(malicious):
+                    malicious_dict[m] = i
+
+                # Exclude training samples from the test set
+                test_idx = []
+                for i in range(x_test.shape[0]):
+                    if i >= skip_benign or y_test[i] == 1.0:
+                        test_idx.append(i)
+                result_x_test = x_test[test_idx]
+                result_y_test = y_test[test_idx]
+                del x_test, y_test
+                test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                           result_y_test)
+                torch.save(model.state_dict(), "./result/FedAvg-{}.pt".format(self.name))
+        return test_auc, model
+
--
GitLab