Skip to content
Snippets Groups Projects
Commit f1d05b1d authored by Athmane Mansour Bahar's avatar Athmane Mansour Bahar
Browse files

Upload New File

parent a0b98d7a
No related branches found
No related tags found
No related merge requests found
eval.py 0 → 100644
import torch
import warnings
from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
from model.autoencoder import build_model
from utils.poolers import Pooling
from utils.utils import set_random_seed
import numpy as np
from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
from utils.config import build_args
warnings.filterwarnings('ignore')
def main(main_args):
    """Evaluate a trained graph autoencoder on an anomaly-detection dataset.

    For batch-level datasets ('streamspot', 'wget') the model is scored
    graph-by-graph with a KNN-based evaluator; for entity-level datasets
    (e.g. 'trace') node embeddings are extracted and malicious nodes are
    detected with a KNN distance test. Prints the test AUC +/- std.

    Args:
        main_args: argparse-style namespace from ``build_args()``; hidden
            size, layer count and feature dims are filled in here before
            the model is built.
    """
    device = torch.device("cpu")
    # NOTE(review): dataset name is hard-coded; presumably it should come
    # from main_args.dataset -- confirm against build_args().
    dataset_name = "trace"

    # Dataset-family-specific hyper-parameters.
    if dataset_name in ('streamspot', 'wget'):
        main_args.num_hidden = 256
        main_args.num_layers = 4
    else:
        # BUG FIX: original used item assignment (main_args["num_hidden"]),
        # which raises TypeError on an argparse Namespace; use attribute
        # assignment like the rest of the function.
        main_args.num_hidden = 64
        main_args.num_layers = 3
    set_random_seed(0)

    if dataset_name in ('streamspot', 'wget'):
        # ---- Batch-level (whole-graph) evaluation ----
        dataset = load_batch_level_dataset(dataset_name)
        main_args.n_dim = dataset['n_feat']
        main_args.e_dim = dataset['e_feat']
        model = build_model(main_args)
        model.load_state_dict(
            torch.load("./result/FedOpt-{}.pt".format(dataset_name), map_location=device))
        model = model.to(device)
        pooler = Pooling(main_args.pooling)
        # BUG FIX: original referenced the module-level global `args.dataset`
        # instead of the local dataset name.
        test_auc, test_std = batch_level_evaluation(
            model, pooler, device, ['knn'], dataset_name,
            main_args.n_dim, main_args.e_dim)
    else:
        # ---- Entity-level (per-node) evaluation ----
        metadata = load_metadata(dataset_name)
        main_args.n_dim = metadata['node_feature_dim']
        main_args.e_dim = metadata['edge_feature_dim']
        model = build_model(main_args)
        model.load_state_dict(
            torch.load("./result/checkpoint-{}.pt".format(dataset_name), map_location=device))
        model = model.to(device)
        model.eval()
        # metadata['malicious'] is a pair; only the node-index list is needed.
        malicious, _ = metadata['malicious']
        n_train = metadata['n_train']
        n_test = metadata['n_test']

        with torch.no_grad():
            # Embed every training graph; rows are node embeddings.
            x_train = []
            for i in range(n_train):
                g = load_entity_level_dataset(dataset_name, 'train', i).to(device)
                x_train.append(model.embed(g).cpu().detach().numpy())
                del g
            x_train = np.concatenate(x_train, axis=0)

            # Embed test graphs. Every test graph except the last contains
            # benign nodes that also appear in training; count those nodes
            # so they can be excluded from scoring below.
            skip_benign = 0
            x_test = []
            for i in range(n_test):
                g = load_entity_level_dataset(dataset_name, 'test', i).to(device)
                if i != n_test - 1:
                    skip_benign += g.number_of_nodes()
                x_test.append(model.embed(g).cpu().detach().numpy())
                del g
            x_test = np.concatenate(x_test, axis=0)

            n = x_test.shape[0]
            y_test = np.zeros(n)
            y_test[malicious] = 1.0

            # Keep malicious nodes plus benign nodes unseen during training.
            # (Removed dead local `malicious_dict` -- built but never used.)
            test_idx = [i for i in range(n) if i >= skip_benign or y_test[i] == 1.0]
            result_x_test = x_test[test_idx]
            result_y_test = y_test[test_idx]
            del x_test, y_test
            test_auc, test_std, _, _ = evaluate_entity_level_using_knn(
                dataset_name, x_train, result_x_test, result_y_test)

    print(f"#Test_AUC: {test_auc:.4f}±{test_std:.4f}")
    return
if __name__ == '__main__':
    # Parse command-line options and run the evaluation. `args` is bound at
    # module level because main() (as originally written) reads the global
    # `args.dataset` in its batch-level branch.
    args = build_args()
    main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment