import logging
import random

import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader

from fedml.core import ClientTrainer

from data.data_loader import transform_data
from model.autoencoder import build_model
from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
from model.train import batch_level_train
from utils.config import build_args
from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
from utils.poolers import Pooling
from utils.utils import create_optimizer

# FedML client trainer for the MAGIC graph masked autoencoder.
# The evaluation metric is ROC-AUC, computed with a KNN detector.
def extract_dataloaders(entries, batch_size):
    # Shuffle the training entries, then wrap them in a DGL GraphDataLoader
    # that samples batches through a SubsetRandomSampler.
    random.shuffle(entries)
    train_idx = torch.arange(len(entries))
    train_sampler = SubsetRandomSampler(train_idx)
    train_loader = GraphDataLoader(entries, batch_size=batch_size, sampler=train_sampler)
    return train_loader
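
# A minimal sketch of the intended call pattern, mirroring the train() method
# below; "train_index" stands for the per-client index list produced by
# transform_data, and batch_size follows the streamspot setting. The names
# here are illustrative, not part of this file's API.
#
#   loader = extract_dataloaders(train_index, batch_size=12)
#   model = batch_level_train(model, graphs, loader, optimizer,
#                             max_epoch, device, n_node_feat, n_edge_feat)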

class MagicTrainer(ClientTrainer):
    def __init__(self, model, args, name):
        super().__init__(model, args)
        self.name = name
        # Best test score observed so far across federated rounds.
        self.max = 0

    def get_model_params(self):
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        logging.info("set_model_params")
        self.model.load_state_dict(model_parameters)
    def train(self, train_data, device, args):
        test_data = None
        # Rebuild the hyper-parameters locally; the dataset name selects the
        # per-dataset architecture and training schedule.
        args = build_args()
        if self.name == "wget":
            args["num_hidden"] = 256
            args["max_epoch"] = 2
            args["num_layers"] = 4
            batch_size = 1
        elif self.name == "streamspot":
            args["num_hidden"] = 256
            args["max_epoch"] = 5
            args["num_layers"] = 4
            batch_size = 12
        else:
            args["num_hidden"] = 64
            args["max_epoch"] = 50
            args["num_layers"] = 3
        best_model_params = {}
        if self.name in ('wget', 'streamspot'):
            # Batch-level datasets: train the autoencoder on whole graphs.
            dataset = load_batch_level_dataset(self.name)
            data = transform_data(train_data)
            n_node_feat = dataset['n_feat']
            n_edge_feat = dataset['e_feat']
            graphs = data['dataset']
            train_index = data['train_index']
            args["n_dim"] = n_node_feat
            args["e_dim"] = n_edge_feat
            # self.model = build_model(args)
            self.model = self.model.to(device)
            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
            self.model = batch_level_train(self.model, graphs,
                                           extract_dataloaders(train_index, batch_size),
                                           optimizer, args["max_epoch"], device,
                                           n_node_feat, n_edge_feat)
            test_score, _ = self.test(test_data, device, args)
        else:
            # Entity-level datasets: iterate over the provenance graphs,
            # taking one optimizer step per graph.
            metadata = load_metadata(self.name)
            args["n_dim"] = metadata['node_feature_dim']
            args["e_dim"] = metadata['edge_feature_dim']
            self.model = self.model.to(device)
            self.model.train()
            optimizer = create_optimizer(args["optimizer"], self.model, args["lr"], args["weight_decay"])
            epoch_iter = tqdm(range(args["max_epoch"]))
            n_train = len(train_data[0])
            for epoch in epoch_iter:
                epoch_loss = 0.0
                for i in range(n_train):
                    g = train_data[0][i]
                    self.model.train()
                    loss = self.model(g)
                    loss /= n_train
                    optimizer.zero_grad()
                    epoch_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                    del g
                epoch_iter.set_description(f"Epoch {epoch} | train_loss: {epoch_loss:.4f}")
        if self.name in ('wget', 'streamspot'):
            test_score, _ = self.test(test_data, device, args)
            if test_score > self.max:
                self.max = test_score
                best_model_params = {
                    k: v.cpu() for k, v in self.model.state_dict().items()
                }
        else:
            self.max = 0
            best_model_params = {
                k: v.cpu() for k, v in self.model.state_dict().items()
            }
        return self.max, best_model_params
    def test(self, test_data, device, args):
        if self.name in ('wget', 'streamspot'):
            logging.info("----------test--------")
            args = build_args()
            model = self.model
            model.eval()
            model.to(device)
            pooler = Pooling(args["pooling"])
            dataset = load_batch_level_dataset(self.name)
            n_node_feat = dataset['n_feat']
            n_edge_feat = dataset['e_feat']
            args["n_dim"] = n_node_feat
            args["e_dim"] = n_edge_feat
            test_auc, test_std = batch_level_evaluation(model, pooler, device, ['knn'],
                                                        self.name, args["n_dim"], args["e_dim"])
        else:
            metadata = load_metadata(self.name)
            args["n_dim"] = metadata['node_feature_dim']
            args["e_dim"] = metadata['edge_feature_dim']
            model = self.model.to(device)
            model.eval()
            malicious, _ = metadata['malicious']
            n_train = metadata['n_train']
            n_test = metadata['n_test']
            with torch.no_grad():
                # Embed all training graphs to build the benign reference set.
                x_train = []
                for i in range(n_train):
                    g = load_entity_level_dataset(self.name, 'train', i).to(device)
                    x_train.append(model.embed(g).cpu().detach().numpy())
                    del g
                x_train = np.concatenate(x_train, axis=0)
                skip_benign = 0
                x_test = []
                for i in range(n_test):
                    g = load_entity_level_dataset(self.name, 'test', i).to(device)
                    # Exclude training samples from the test set
                    if i != n_test - 1:
                        skip_benign += g.number_of_nodes()
                    x_test.append(model.embed(g).cpu().detach().numpy())
                    del g
                x_test = np.concatenate(x_test, axis=0)

                n = x_test.shape[0]
                y_test = np.zeros(n)
                y_test[malicious] = 1.0
                malicious_dict = {}
                for i, m in enumerate(malicious):
                    malicious_dict[m] = i

                # Exclude training samples from the test set
                test_idx = []
                for i in range(x_test.shape[0]):
                    if i >= skip_benign or y_test[i] == 1.0:
                        test_idx.append(i)
                result_x_test = x_test[test_idx]
                result_y_test = y_test[test_idx]
                del x_test, y_test
                test_auc, test_std, _, _ = evaluate_entity_level_using_knn(self.name, x_train,
                                                                           result_x_test, result_y_test)
        return test_auc, model
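
# A minimal, illustrative sketch of wiring this trainer into a local run,
# assuming build_args/build_model above; the feature dimensions and dataset
# name are placeholders, not values taken from this file.
#
#   if __name__ == "__main__":
#       args = build_args()
#       args["n_dim"], args["e_dim"] = 100, 100  # hypothetical feature dims
#       model = build_model(args)
#       trainer = MagicTrainer(model, args, "streamspot")
#       best_auc, best_params = trainer.train(train_data, torch.device("cpu"), args)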