From fbcf8812a663d7e0a1e3c23c01e4bb0a8cca3fd0 Mon Sep 17 00:00:00 2001
From: Athmane Mansour Bahar <ja_mansourbahar@esi.dz>
Date: Thu, 15 Aug 2024 17:48:41 +0000
Subject: [PATCH] Upload New File

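Add a server-side aggregator for MAGIC graph models. After each aggregation
round it evaluates the global model on every client's test split: batch-level
KNN evaluation for the wget/streamspot datasets and entity-level KNN detection
for the remaining datasets, logging per-client and average ROC-AUC.

Rough usage sketch (assumes FedML's FedMLRunner entry point; the
load_data_and_model helper is hypothetical and stands in for this repo's
data/model setup):

    import fedml
    from fedml import FedMLRunner
    from trainer.magic_aggregator import MagicWgetAggregator

    args = fedml.init()
    device = fedml.device.get_device(args)
    dataset, model = load_data_and_model(args)  # hypothetical helper
    aggregator = MagicWgetAggregator(model, args, name=args.dataset)
    FedMLRunner(args, device, dataset, model,
                server_aggregator=aggregator).run()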
---
 trainer/magic_aggregator.py | 149 ++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+)
 create mode 100644 trainer/magic_aggregator.py

diff --git a/trainer/magic_aggregator.py b/trainer/magic_aggregator.py
new file mode 100644
index 0000000..a9d7f70
--- /dev/null
+++ b/trainer/magic_aggregator.py
@@ -0,0 +1,149 @@
+import logging
+
+import numpy as np
+import torch
+import wandb
+from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
+from utils.config import build_args
+from fedml.core import ServerAggregator
+from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
+from utils.poolers import Pooling
+# Server-side aggregator for MAGIC graph models. The evaluation metric is ROC-AUC.
+from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset
+from utils.loaddata import load_batch_level_dataset
+
+class MagicWgetAggregator(ServerAggregator):
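+    """FedML ServerAggregator for MAGIC graph models.
+
+    After aggregation, the server evaluates the global model with batch-level
+    KNN evaluation for the 'wget' and 'streamspot' datasets, and with
+    entity-level KNN detection (per-node anomaly scores) for the remaining
+    datasets, reporting ROC-AUC.
+    """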
+    def __init__(self, model, args, name):
+        super().__init__(model, args)
+        self.name = name
+
+    def get_model_params(self):
+        return self.model.cpu().state_dict()
+
+    def set_model_params(self, model_parameters):
+        logging.info("set_model_params")
+        self.model.load_state_dict(model_parameters)
+
+    def test(self, test_data, device, args):
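+        # Per-client testing is a no-op here; evaluation happens in test_all.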
+        pass
+
+    def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool:
+        logging.info("----------test_on_the_server--------")
+
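+        # Evaluate the aggregated global model on each client's test split and
+        # cross-check that every evaluated copy of the model stays identical.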
+        model_list, score_list = [], []
+        for client_idx in test_data_local_dict.keys():
+            test_data = test_data_local_dict[client_idx]
+            score, model = self._test(test_data, device, args)
+            for idx in range(len(model_list)):
+                self._compare_models(model, model_list[idx])
+            model_list.append(model)
+            score_list.append(score)
+            logging.info("Client {}, Test ROC-AUC score = {}".format(client_idx, score))
+            if args.enable_wandb:
+                wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
+        avg_score = np.mean(np.array(score_list))
+        logging.info("Test ROC-AUC Score = {}".format(avg_score))
+        if args.enable_wandb:
+            wandb.log({"Test/ROC-AUC": avg_score})
+        return True
+
+    def _compare_models(self, model_1, model_2):
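+        # Sanity check: parameter-by-parameter comparison of two state_dicts;
+        # logs any mismatching tensor and raises if the keys themselves differ.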
+        models_differ = 0
+        for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
+            if torch.equal(key_item_1[1], key_item_2[1]):
+                pass
+            else:
+                models_differ += 1
+                if key_item_1[0] == key_item_2[0]:
+                    logging.info("Mismatch found at %s", key_item_1[0])
+                else:
+                    raise Exception("state_dict keys do not match: {} vs {}".format(key_item_1[0], key_item_2[0]))
+        if models_differ == 0:
+            logging.info("Models match perfectly! :)")
+
+    def _test(self, test_data, device, args):
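+        # NOTE: build_args() reloads the MAGIC configuration and shadows the FedML
+        # `args` argument; the helpers below expect dict-style access (args["..."]).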
+        args = build_args()
+        if self.name in ('wget', 'streamspot'):
+            logging.info("----------test--------")
+
+            model = self.model
+            model.eval()
+            model.to(device)
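+            # Batch-level (graph-level) evaluation: pool node embeddings per graph
+            # and score the pooled graph embeddings with the KNN-based detector.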
+            pooler = Pooling(args["pooling"])
+            dataset = load_batch_level_dataset(self.name)
+            n_node_feat = dataset['n_feat']
+            n_edge_feat = dataset['e_feat']
+            args["n_dim"] = n_node_feat
+            args["e_dim"] = n_edge_feat
+            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name, args["n_dim"], args["e_dim"])
+        else:
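+            # Entity-level datasets: embed every graph, then run KNN-based
+            # per-node anomaly detection on the node embeddings.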
+            torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name))
+            metadata = load_metadata(self.name)
+            args["n_dim"] = metadata['node_feature_dim']
+            args["e_dim"] = metadata['edge_feature_dim']
+            model = self.model.to(device)
+            model.eval()
+            malicious, _ = metadata['malicious']
+            n_train = metadata['n_train']
+            n_test = metadata['n_test']
+
+            with torch.no_grad():
+                x_train = []
+                for i in range(n_train):
+                    g = load_entity_level_dataset(self.name, 'train', i).to(device)
+                    x_train.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_train = np.concatenate(x_train, axis=0)
+                skip_benign = 0
+                x_test = []
+                for i in range(n_test):
+                    g = load_entity_level_dataset(self.name, 'test', i).to(device)
+                    # Exclude training samples from the test set
+                    if i != n_test - 1:
+                        skip_benign += g.number_of_nodes()
+                    x_test.append(model.embed(g).cpu().detach().numpy())
+                    del g
+                x_test = np.concatenate(x_test, axis=0)
+
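+            # Node-level labels: 1 for nodes listed as malicious in the metadata.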
+            n = x_test.shape[0]
+            y_test = np.zeros(n)
+            y_test[malicious] = 1.0
+            malicious_dict = {}
+            for i, m in enumerate(malicious):
+                malicious_dict[m] = i
+
+            # Exclude training samples from the test set
+            test_idx = []
+            for i in range(x_test.shape[0]):
+                if i >= skip_benign or y_test[i] == 1.0:
+                    test_idx.append(i)
+            result_x_test = x_test[test_idx]
+            result_y_test = y_test[test_idx]
+            del x_test, y_test
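+            # KNN detector fitted on the training embeddings; returns the test
+            # ROC-AUC and its standard deviation (remaining outputs are unused here).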
+            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
+                                                                       result_y_test)
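+        # Persist the evaluated global model for this dataset.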
+        torch.save(model.state_dict(), "./result/FedAvg-{}.pt".format(self.name))
+        return test_auc, model
+
-- 
GitLab