import fedml
from fedml import FedMLRunner

from data.data_loader import load_partition_data, darpa_split
from model.autoencoder import build_model
from trainer.magic_trainer import MagicTrainer
from trainer.magic_aggregator import MagicWgetAggregator
from trainer.single_trainer import train_single
from utils.config import build_args
from utils.loaddata import load_batch_level_dataset, load_metadata


def generate_dataset(name, number):
    """Partition the named dataset across `number` clients and return the
    FedML-style dataset list together with its graph metadata."""
    (
        train_data_num,
        val_data_num,
        test_data_num,
        train_data_global,
        val_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        val_data_local_dict,
        test_data_local_dict,
    ) = load_partition_data(None, number, name)
    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        len(train_data_global),
    ]
    # wget and streamspot are batch-level (whole-graph) datasets; the other
    # (DARPA) datasets expose their feature dimensions via a metadata file.
    if name in ("wget", "streamspot"):
        return dataset, load_batch_level_dataset(name)
    return dataset, load_metadata(name)


if __name__ == "__main__":
    # init FedML framework
    args = fedml.init()

    # init device
    device = fedml.device.get_device(args)

    # load and partition the dataset
    name = args.dataset
    number = args.client_num_in_total
    dataset, metadata = generate_dataset(name, number)

    # dataset-specific model hyperparameters
    main_args = build_args()
    if name == "wget":
        main_args["num_hidden"] = 256
        main_args["max_epoch"] = 2
        main_args["num_layers"] = 4
        main_args["n_dim"] = metadata["n_feat"]
        main_args["e_dim"] = metadata["e_feat"]
    elif name == "streamspot":
        main_args["num_hidden"] = 256
        main_args["max_epoch"] = 5
        main_args["num_layers"] = 4
        main_args["n_dim"] = metadata["n_feat"]
        main_args["e_dim"] = metadata["e_feat"]
    else:
        main_args["num_hidden"] = 64
        main_args["max_epoch"] = 50
        main_args["num_layers"] = 3
        main_args["n_dim"] = metadata["node_feature_dim"]
        main_args["e_dim"] = metadata["edge_feature_dim"]

    # build the graph autoencoder
    model = build_model(main_args)
    # train_single(main_args, model, data)  # centralized (non-federated) baseline

    # custom trainer / aggregator used by the FedML runner
    trainer = MagicTrainer(model, args, name)
    aggregator = MagicWgetAggregator(model, args, name)

    # start training
    fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator)
    fedml_runner.run()

    # darpa_split("theia")