import logging

import fedml
from data.data_loader import load_partition_data, load_batch_level_dataset_main, darpa_split
from fedml import FedMLRunner
from trainer.magic_trainer import MagicTrainer
from trainer.magic_aggregator import MagicWgetAggregator
from model.autoencoder import build_model
from utils.config import build_args
from trainer.magic_trainer import MagicTrainer
from trainer.magic_aggregator import MagicWgetAggregator
from trainer.single_trainer import train_single
from utils.loaddata import load_batch_level_dataset, load_metadata



def generate_dataset(name, number):
    """Load the federated partitions for *name* together with its metadata.

    Args:
        name: dataset identifier, forwarded unchanged to the project loaders.
        number: client count, passed as the second argument of
            ``load_partition_data``.

    Returns:
        A pair ``(dataset, metadata)``.

        ``dataset`` is an 8-element list, in this order:
        train count, test count, global train data, global test data,
        ``data_local_num_dict``, per-client train dict, per-client test dict,
        and ``len(global train data)``. The validation values returned by
        ``load_partition_data`` are unpacked but not included.

        ``metadata`` comes from ``load_batch_level_dataset(name)`` for
        "wget" and "streamspot". For any other name it comes from
        ``load_metadata(name)``.
    """
    (
        n_train,
        _n_val,
        n_test,
        global_train,
        _global_val,
        global_test,
        local_counts,
        local_train,
        _local_val,
        local_test,
    ) = load_partition_data(None, number, name)

    dataset = [
        n_train,
        n_test,
        global_train,
        global_test,
        local_counts,
        local_train,
        local_test,
        len(global_train),
    ]

    # The two batch-level datasets have their own metadata loader.
    # Every other dataset goes through load_metadata.
    is_batch_level = name in ("wget", "streamspot")
    metadata = load_batch_level_dataset(name) if is_batch_level else load_metadata(name)
    return dataset, metadata
           
    
if __name__ == "__main__":
    # Initialise FedML from its config/CLI, then resolve the device from the
    # same argument object.
    args = fedml.init()
    device = fedml.device.get_device(args)
    name = args.dataset
    client_count = args.client_num_in_total

    dataset, metadata = generate_dataset(name, client_count)
    main_args = build_args()

    # Per-dataset overrides applied on top of build_args(). Each entry is
    # (num_hidden, max_epoch, num_layers, metadata key for n_dim,
    #  metadata key for e_dim).
    # "wget" and "streamspot" read their feature sizes from the 'n_feat' and
    # 'e_feat' metadata keys. Every other dataset uses 'node_feature_dim' and
    # 'edge_feature_dim'.
    per_dataset = {
        "wget": (256, 2, 4, "n_feat", "e_feat"),
        "streamspot": (256, 5, 4, "n_feat", "e_feat"),
    }
    fallback = (64, 50, 3, "node_feature_dim", "edge_feature_dim")
    hidden, epochs, layers, node_key, edge_key = per_dataset.get(name, fallback)

    main_args["num_hidden"] = hidden
    main_args["max_epoch"] = epochs
    main_args["num_layers"] = layers
    main_args["n_dim"] = metadata[node_key]
    main_args["e_dim"] = metadata[edge_key]

    # Build the model, wrap it in the client trainer and server aggregator,
    # and start federated training.
    model = build_model(main_args)
    # Alternative entry points, kept commented out for manual use:
    # train_single(main_args, model, data)
    # darpa_split("theia")
    trainer = MagicTrainer(model, args, name)
    aggregator = MagicWgetAggregator(model, args, name)
    fedml_runner = FedMLRunner(args, device, dataset, model, trainer, aggregator)
    fedml_runner.run()