"""Ray Tune hyper-parameter search for ``Model_Common_Transformer``.

Each trial trains the model on the objective selected by ``args.forward``:
``'rt'`` (retention time), ``'int'`` (intensity), or anything else (both).
After every epoch it reports the mean validation loss to Ray Tune together
with a checkpoint. The best trial is finally evaluated on the test split.
"""
import os
import tempfile

import torch
import torch.optim as optim
from ray import train, tune
from ray.air import RunConfig, CheckpointConfig  # noqa: F401
from ray.train import Checkpoint
# NOTE(review): the alternative search algorithms / schedulers below are not
# used by main(); presumably kept for quick experimentation.
from ray.tune.schedulers import HyperBandForBOHB, ASHAScheduler  # noqa: F401
from ray.tune.search.ax import AxSearch  # noqa: F401
from ray.tune.search.bayesopt import BayesOptSearch  # noqa: F401
from ray.tune.search.bohb import TuneBOHB  # noqa: F401
from ray.tune.search.optuna import OptunaSearch
from ray.util.client import ray  # noqa: F401

import common_dataset
import dataloader
from config_common import load_args
from loss import masked_cos_sim
from model_custom import Model_Common_Transformer

# Sequence length passed to both data loaders.
SEQ_LENGTH = 25
# Epochs per trial; also the ASHA scheduler's max_t (one report per epoch).
NUM_EPOCHS = 100
# Print the running training loss every this many mini-batches.
LOG_EVERY = 2000
# Parent directory for the temporary per-epoch checkpoint directories.
CHECKPOINT_ROOT = '/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/checkpoints'


def _build_model(config):
    """Instantiate ``Model_Common_Transformer`` from a Ray Tune config dict."""
    return Model_Common_Transformer(
        encoder_ff=int(config["encoder_ff"]),
        decoder_rt_ff=int(config["decoder_rt_ff"]),
        decoder_int_ff=int(config["decoder_int_ff"]),
        n_head=int(config["n_head"]),
        encoder_num_layer=int(config["encoder_num_layer"]),
        decoder_int_num_layer=int(config["decoder_int_num_layer"]),
        decoder_rt_num_layer=int(config["decoder_rt_num_layer"]),
        drop_rate=float(config["drop_rate"]),
        embedding_dim=int(config["embedding_dim"]),
        acti=config["activation"],
        norm=config["norm_first"],
    )


def _select_device(model):
    """Move ``model`` to the best available device.

    Returns ``(model, device)``. When more than one GPU is visible, the
    returned model is wrapped in ``torch.nn.DataParallel``.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    model.to(device)
    return model, device


def _unwrap(model):
    """Return the underlying module of a ``DataParallel`` wrapper (or ``model``)."""
    return model.module if isinstance(model, torch.nn.DataParallel) else model


def _load_model_state(model, state_dict):
    """Load ``state_dict`` into ``model``, whether or not either was DataParallel.

    Checkpoints written from a ``DataParallel`` model prefix every key with
    ``'module.'``. The prefix is stripped so the same file loads on any GPU
    count.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    _unwrap(model).load_state_dict(state_dict)


def _load_data(args, batch_size):
    """Return ``(train, val, test)`` loaders for the objective in ``args.forward``."""
    if args.forward == 'both':
        return common_dataset.load_data(path_train=args.dataset_train,
                                        path_val=args.dataset_val,
                                        path_test=args.dataset_test,
                                        batch_size=batch_size,
                                        length=SEQ_LENGTH)
    return dataloader.load_data(data_source=args.dataset_train,
                                batch_size=batch_size,
                                length=SEQ_LENGTH)


def _compute_loss(model, data, forward, device, criterion_rt, criterion_intensity):
    """Run ``model`` on one batch and return the (differentiable) loss tensor.

    ``forward`` selects the batch layout and objective:
      * ``'rt'``  -- batch ``(seq, rt)``, loss ``criterion_rt``;
      * ``'int'`` -- batch ``(seq, charge, intensity)``, loss ``criterion_intensity``;
      * otherwise -- batch ``(seq, charge, rt, intensity)``, sum of both losses.
    """
    # DataParallel only parallelises __call__/forward; the single-head
    # methods must be reached on the wrapped module.
    core = _unwrap(model)
    if forward == 'rt':
        seq, rt = data
        seq, rt = seq.to(device), rt.float().to(device)
        return criterion_rt(rt, core.forward_rt(seq))
    if forward == 'int':
        seq, charge, intensity = data
        seq, charge = seq.to(device), charge.to(device)
        intensity = intensity.float().to(device)
        return criterion_intensity(intensity, core.forward_int(seq, charge))
    seq, charge, rt, intensity = data
    seq, charge = seq.to(device), charge.to(device)
    rt, intensity = rt.float().to(device), intensity.float().to(device)
    pred_rt, pred_int = model(seq, charge)
    return criterion_rt(rt, pred_rt) + criterion_intensity(intensity, pred_int)


def _evaluate(model, loader, forward, device, criterion_rt, criterion_intensity):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is put in eval mode (dropout disabled) and gradients are off.
    Returns ``nan`` for an empty loader instead of dividing by zero.
    """
    model.eval()
    total, steps = 0.0, 0
    with torch.no_grad():
        for data in loader:
            total += _compute_loss(model, data, forward, device,
                                   criterion_rt, criterion_intensity).item()
            steps += 1
    return total / steps if steps else float("nan")


def train_model(config, args):
    """Ray Tune trainable: train one configuration and report validation loss.

    Args:
        config: sampled hyper-parameters (see ``main``).
        args: parsed CLI arguments; uses ``forward`` and the dataset paths.
    """
    net, device = _select_device(_build_model(config))
    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    # Resume from a checkpoint when Ray Tune restores this trial.
    # NOTE(review): the epoch counter is not checkpointed, so a restored
    # trial runs NUM_EPOCHS more epochs.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,
            )
            _load_model_state(net, model_state)
            optimizer.load_state_dict(optimizer_state)

    data_train, data_val, _ = _load_data(args, int(config["batch_size"]))

    for epoch in range(NUM_EPOCHS):
        net.train()  # re-enable dropout after the previous epoch's validation
        running_loss = 0.0
        for i, data in enumerate(data_train):
            loss = _compute_loss(net, data, args.forward, device,
                                 criterion_rt, criterion_intensity)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % LOG_EVERY == LOG_EVERY - 1:
                # Average over the window since the last reset.
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / LOG_EVERY))
                running_loss = 0.0

        val_loss = _evaluate(net, data_val, args.forward, device,
                             criterion_rt, criterion_intensity)

        # A Checkpoint must be built from a directory; Ray registers it with
        # the trial so get_checkpoint() can return it on restore.
        with tempfile.TemporaryDirectory(dir=CHECKPOINT_ROOT) as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save((net.state_dict(), optimizer.state_dict()), path)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report({"loss": val_loss}, checkpoint=checkpoint)
    print("Finished Training")


def test_best_model(best_result, args):
    """Rebuild the best trial's model from its checkpoint and print its mean test loss."""
    config = best_result.config
    model, device = _select_device(_build_model(config))
    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
    model_state, _optimizer_state = torch.load(checkpoint_path, map_location=device)
    _load_model_state(model, model_state)

    _, _, data_test = _load_data(args, int(config["batch_size"]))
    test_loss = _evaluate(model, data_test, args.forward, device,
                          criterion_rt, criterion_intensity)
    print("Best trial test set loss: {}".format(test_loss))


def main(args, gpus_per_trial=1):
    """Run the Optuna + ASHA search, then evaluate the best trial on the test set."""
    config = {
        "encoder_num_layer": tune.choice([2, 4, 8]),
        "decoder_rt_num_layer": tune.choice([2, 4, 8]),
        "decoder_int_num_layer": tune.choice([1]),
        "embedding_dim": tune.choice([16, 64]),
        "encoder_ff": tune.choice([512, 1024, 2048]),
        "decoder_rt_ff": tune.choice([512, 1024, 2048]),
        "decoder_int_ff": tune.choice([512]),
        "n_head": tune.choice([1, 2, 4, 8, 16]),
        "drop_rate": tune.choice([0.25]),
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([4096]),
        "activation": tune.choice(['relu', 'gelu']),
        "norm_first": tune.choice([True, False]),
    }
    # max_t is counted in reported iterations, i.e. epochs of train_model.
    scheduler = ASHAScheduler(
        max_t=NUM_EPOCHS,
        grace_period=30,
        reduction_factor=3,
        brackets=1,
    )
    algo = OptunaSearch()
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": 80, "gpu": gpus_per_trial},
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=20,
            metric='loss',
            mode='min',
        ),
        run_config=RunConfig(
            storage_path="/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/ray_results_test",
            name="test_experiment_no_scheduler",
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")
    print("Best trial config: {}".format(best_result.config))
    # Only "loss" is reported by train_model; there is no accuracy metric.
    print("Best trial final validation loss: {}".format(best_result.metrics["loss"]))
    test_best_model(best_result, args)


if __name__ == "__main__":
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_properties(i).name)
    # NOTE(review): this seeds the driver process only; Ray trial workers
    # are separate processes and are not seeded by this call.
    torch.manual_seed(2809)
    arg = load_args()
    main(arg, gpus_per_trial=4)