Skip to content
Snippets Groups Projects
Commit fbcf8812 authored by Athmane Mansour Bahar's avatar Athmane Mansour Bahar
Browse files

Upload New File

parent d2da9340
No related branches found
No related tags found
No related merge requests found
import logging
import numpy as np
import torch
import wandb
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from utils.config import build_args
from fedml.core import ServerAggregator
from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
from utils.poolers import Pooling
# Trainer for MoleculeNet. The evaluation metric is ROC-AUC
from data.data_loader import load_batch_level_dataset_main, load_metadata, load_entity_level_dataset
from utils.loaddata import load_batch_level_dataset
class MagicWgetAggregator(ServerAggregator):
def __init__(self, model, args, name):
super().__init__(model, args)
self.name = name
def get_model_params(self):
return self.model.cpu().state_dict()
def set_model_params(self, model_parameters):
logging.info("set_model_params")
self.model.load_state_dict(model_parameters)
def test(self, test_data, device, args):
pass
def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool:
logging.info("----------test_on_the_server--------")
model_list, score_list = [], []
for client_idx in test_data_local_dict.keys():
test_data = test_data_local_dict[client_idx]
score, model = self._test(test_data, device, args)
for idx in range(len(model_list)):
self._compare_models(model, model_list[idx])
model_list.append(model)
score_list.append(score)
logging.info("Client {}, Test ROC-AUC score = {}".format(client_idx, score))
if args.enable_wandb:
wandb.log({"Client {} Test/ROC-AUC".format(client_idx): score})
avg_score = np.mean(np.array(score_list))
logging.info("Test ROC-AUC Score = {}".format(avg_score))
if args.enable_wandb:
wandb.log({"Test/ROC-AUC": avg_score})
return True
def _compare_models(self, model_1, model_2):
models_differ = 0
for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
if torch.equal(key_item_1[1], key_item_2[1]):
pass
else:
models_differ += 1
if key_item_1[0] == key_item_2[0]:
logging.info("Mismatch found at", key_item_1[0])
else:
raise Exception
if models_differ == 0:
logging.info("Models match perfectly! :)")
def _test(self, test_data, device, args):
args = build_args()
if (self.name == 'wget' or self.name == 'streamspot'):
logging.info("----------test--------")
model = self.model
model.eval()
model.to(device)
pooler = Pooling(args["pooling"])
dataset = load_batch_level_dataset(self.name)
n_node_feat = dataset['n_feat']
n_edge_feat = dataset['e_feat']
args["n_dim"] = n_node_feat
args["e_dim"] = n_edge_feat
test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'], self.name ,args["n_dim"], args["e_dim"])
else:
torch.save(self.model.state_dict(), "./result/FedAvg-4client-{}.pt".format(self.name))
metadata = load_metadata(self.name)
args["n_dim"] = metadata['node_feature_dim']
args["e_dim"] = metadata['edge_feature_dim']
model = self.model.to(device)
model.eval()
malicious, _ = metadata['malicious']
n_train = metadata['n_train']
n_test = metadata['n_test']
with torch.no_grad():
x_train = []
for i in range(n_train):
g = load_entity_level_dataset(self.name, 'train', i).to(device)
x_train.append(model.embed(g).cpu().detach().numpy())
del g
x_train = np.concatenate(x_train, axis=0)
skip_benign = 0
x_test = []
for i in range(n_test):
g = load_entity_level_dataset(self.name, 'test', i).to(device)
# Exclude training samples from the test set
if i != n_test - 1:
skip_benign += g.number_of_nodes()
x_test.append(model.embed(g).cpu().detach().numpy())
del g
x_test = np.concatenate(x_test, axis=0)
n = x_test.shape[0]
y_test = np.zeros(n)
y_test[malicious] = 1.0
malicious_dict = {}
for i, m in enumerate(malicious):
malicious_dict[m] = i
# Exclude training samples from the test set
test_idx = []
for i in range(x_test.shape[0]):
if i >= skip_benign or y_test[i] == 1.0:
test_idx.append(i)
result_x_test = x_test[test_idx]
result_y_test = y_test[test_idx]
del x_test, y_test
test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train, result_x_test,
result_y_test)
torch.save(model.state_dict(), "./result/FedAvg-{}.pt".format(self.name))
return test_auc, model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment