# main_ray_tune.py — Ray Tune hyper-parameter search for Model_Common_Transformer.
# Authored by Léo Schneider (commit 6ac8b84f).
import os
import tempfile
import torch
import torch.optim as optim
from ray.air import RunConfig, CheckpointConfig
from ray.tune.search.ax import AxSearch
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.optuna import OptunaSearch
from ray.util.client import ray
import common_dataset
import dataloader
from config_common import load_args
from loss import masked_cos_sim
from model_custom import Model_Common_Transformer
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import HyperBandForBOHB, ASHAScheduler
def train_model(config, args):
    """Train one Ray Tune trial of ``Model_Common_Transformer`` and report validation loss.

    The network is built from the hyper-parameters sampled in ``config`` and
    trained for 100 epochs. After every epoch the mean validation loss is sent
    to Ray Tune via ``train.report``, together with a checkpoint holding the
    model and optimizer state, so schedulers can stop or resume the trial.

    Args:
        config: hyper-parameters sampled by Ray Tune. Reads ``encoder_ff``,
            ``decoder_rt_ff``, ``decoder_int_ff``, ``n_head``,
            ``encoder_num_layer``, ``decoder_int_num_layer``,
            ``decoder_rt_num_layer``, ``drop_rate``, ``embedding_dim``,
            ``activation``, ``norm_first``, ``lr`` and ``batch_size``.
        args: command-line arguments. Reads ``forward`` (``'rt'``, ``'int'``;
            any other value is treated as the joint ``'both'`` objective) and
            the ``dataset_train`` / ``dataset_val`` / ``dataset_test`` paths.

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    net = Model_Common_Transformer(
        encoder_ff=int(config["encoder_ff"]),
        decoder_rt_ff=int(config["decoder_rt_ff"]),
        decoder_int_ff=int(config["decoder_int_ff"]),
        n_head=int(config["n_head"]),
        encoder_num_layer=int(config["encoder_num_layer"]),
        decoder_int_num_layer=int(config["decoder_int_num_layer"]),
        decoder_rt_num_layer=int(config["decoder_rt_num_layer"]),
        drop_rate=float(config["drop_rate"]),
        embedding_dim=int(config["embedding_dim"]),
        acti=config["activation"],
        norm=config["norm_first"],
    )
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            print(type(net))
            net = torch.nn.DataParallel(net)
            print(type(net))
    net.to(device)
    # forward_rt / forward_int are methods of the wrapped module. DataParallel
    # only parallelises its own __call__ forward, so the single-task heads are
    # called on the underlying module (one GPU), as in the original code.
    model = net.module if isinstance(net, torch.nn.DataParallel) else net

    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    # Resume from the checkpoint Ray Tune hands back when it restores this trial.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,
            )
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    batch_size = int(config["batch_size"])
    if args.forward == 'both':
        data_train, data_val, _ = common_dataset.load_data(
            path_train=args.dataset_train,
            path_val=args.dataset_val,
            path_test=args.dataset_test,
            batch_size=batch_size,
            length=25,
        )
    else:
        data_train, data_val, _ = dataloader.load_data(
            data_source=args.dataset_train, batch_size=batch_size, length=25)

    def batch_loss(data):
        """Move one batch to ``device`` and return its loss for ``args.forward``."""
        if args.forward == 'rt':
            seq, rt = data
            seq, rt = seq.to(device), rt.float().to(device)
            return criterion_rt(rt, model.forward_rt(seq))
        if args.forward == 'int':
            seq, charge, intensity = data
            seq, charge = seq.to(device), charge.to(device)
            intensity = intensity.float().to(device)
            return criterion_intensity(intensity, model.forward_int(seq, charge))
        # Joint objective: retention-time MSE plus intensity loss.
        seq, charge, rt, intensity = data
        seq, charge = seq.to(device), charge.to(device)
        rt, intensity = rt.float().to(device), intensity.float().to(device)
        pred_rt, pred_int = net(seq, charge)
        return criterion_rt(rt, pred_rt) + criterion_intensity(intensity, pred_int)

    # Cluster-specific scratch location for the temporary checkpoint directories.
    checkpoint_root = '/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/checkpoints'
    log_every = 2000  # print statistics every `log_every` mini-batches

    for epoch in range(100):  # loop over the dataset multiple times
        net.train()  # re-enable dropout after the previous epoch's validation
        running_loss = 0.0
        for i, data in enumerate(data_train):
            loss = batch_loss(data)
            running_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % log_every == log_every - 1:
                # running_loss is reset after each print, so average over the
                # window rather than over all batches seen this epoch.
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / log_every))
                running_loss = 0.0

        # Validation loss: dropout disabled, no autograd graph built.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        with torch.no_grad():
            for data in data_val:
                # .item() already returns a Python float (no .numpy() on it).
                val_loss += batch_loss(data).item()
                val_steps += 1
        if val_steps == 0:
            raise ValueError("validation data loader yielded no batches")

        # Save a checkpoint and register it with Ray Tune through train.report;
        # it may later be returned by train.get_checkpoint() when the trial is
        # restored. A checkpoint must be built from a directory, hence the
        # temporary directory around the saved file.
        with tempfile.TemporaryDirectory(dir=checkpoint_root) as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save((net.state_dict(), optimizer.state_dict()), path)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report(
                {"loss": (val_loss / val_steps)},
                checkpoint=checkpoint,
            )
    print("Finished Training")
def test_best_model(best_result, args):
    """Rebuild the best trial's model from its checkpoint and print its mean test-set loss.

    Args:
        best_result: Ray Tune result of the best trial. Its ``config`` holds the
            sampled hyper-parameters and its ``checkpoint`` a directory with
            ``checkpoint.pt`` (a ``(model_state, optimizer_state)`` tuple).
        args: command-line arguments. Reads ``forward`` (``'rt'``, ``'int'``;
            any other value is treated as ``'both'``) and the ``dataset_*``
            paths.
    """
    cfg = best_result.config
    best_trained_model = Model_Common_Transformer(
        encoder_ff=int(cfg["encoder_ff"]),
        decoder_rt_ff=int(cfg["decoder_rt_ff"]),
        decoder_int_ff=int(cfg["decoder_int_ff"]),
        n_head=int(cfg["n_head"]),
        # Previously read cfg["batch_size"] here, which built a model whose
        # depth did not match the checkpoint being loaded.
        encoder_num_layer=int(cfg["encoder_num_layer"]),
        decoder_int_num_layer=int(cfg["decoder_int_num_layer"]),
        decoder_rt_num_layer=int(cfg["decoder_rt_num_layer"]),
        drop_rate=float(cfg["drop_rate"]),
        embedding_dim=int(cfg["embedding_dim"]),
        acti=cfg["activation"],
        norm=cfg["norm_first"],
    )
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            # Wrap exactly as during training so the saved state-dict keys match.
            best_trained_model = torch.nn.DataParallel(best_trained_model)
    best_trained_model.to(device)
    # forward_rt / forward_int are methods of the wrapped module.
    model = (best_trained_model.module
             if isinstance(best_trained_model, torch.nn.DataParallel)
             else best_trained_model)

    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim

    with best_result.checkpoint.as_directory() as checkpoint_dir:
        model_state, _optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)
    best_trained_model.eval()  # disable dropout for evaluation

    batch_size = int(cfg["batch_size"])
    if args.forward == 'both':
        _, _, data_test = common_dataset.load_data(
            path_train=args.dataset_train,
            path_val=args.dataset_val,
            path_test=args.dataset_test,
            batch_size=batch_size,
            length=25,
        )
    else:
        _, _, data_test = dataloader.load_data(
            data_source=args.dataset_train, batch_size=batch_size, length=25)

    def batch_loss(data):
        """Move one batch to ``device`` and return its loss for ``args.forward``."""
        if args.forward == 'rt':
            seq, rt = data
            seq, rt = seq.to(device), rt.float().to(device)
            return criterion_rt(rt, model.forward_rt(seq))
        if args.forward == 'int':
            seq, charge, intensity = data
            seq, charge = seq.to(device), charge.to(device)
            intensity = intensity.float().to(device)
            return criterion_intensity(intensity, model.forward_int(seq, charge))
        # Joint objective: retention-time MSE plus intensity loss.
        seq, charge, rt, intensity = data
        seq, charge = seq.to(device), charge.to(device)
        rt, intensity = rt.float().to(device), intensity.float().to(device)
        pred_rt, pred_int = best_trained_model(seq, charge)
        return criterion_rt(rt, pred_rt) + criterion_intensity(intensity, pred_int)

    test_loss = 0.0
    test_steps = 0
    with torch.no_grad():
        for data in data_test:
            # .item() already returns a Python float (no .numpy() on it).
            test_loss += batch_loss(data).item()
            test_steps += 1
    # Report the per-batch mean, matching the "loss" metric train_model reports.
    mean_loss = test_loss / test_steps if test_steps else float("nan")
    print("Best trial test set loss: {}".format(mean_loss))
def main(args, gpus_per_trial=1, cpus_per_trial=80):
    """Run the Ray Tune hyper-parameter search, then evaluate the best trial on the test set.

    Args:
        args: parsed command-line arguments, forwarded to ``train_model`` and
            ``test_best_model``.
        gpus_per_trial: GPUs reserved for each trial.
        cpus_per_trial: CPUs reserved for each trial (default 80, the value
            previously hard-coded here).
    """
    config = {
        "encoder_num_layer": tune.choice([2, 4, 8]),
        "decoder_rt_num_layer": tune.choice([2, 4, 8]),
        "decoder_int_num_layer": tune.choice([1]),
        "embedding_dim": tune.choice([16, 64]),
        "encoder_ff": tune.choice([512, 1024, 2048]),
        "decoder_rt_ff": tune.choice([512, 1024, 2048]),
        "decoder_int_ff": tune.choice([512]),
        "n_head": tune.choice([1, 2, 4, 8, 16]),
        "drop_rate": tune.choice([0.25]),
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([4096]),
        "activation": tune.choice(['relu', 'gelu']),
        "norm_first": tune.choice([True, False]),
    }
    # Time attributes count train.report calls. train_model reports once per
    # epoch for 100 epochs, so max_t=100 lets a surviving trial finish.
    scheduler = ASHAScheduler(
        max_t=100,
        grace_period=30,
        reduction_factor=3,
        brackets=1,
    )
    algo = OptunaSearch()
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial},
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=20,
            metric='loss',
            mode='min',
        ),
        # NOTE(review): the experiment name says "no_scheduler" although an
        # ASHA scheduler is configured; kept as-is since it names the results dir.
        run_config=RunConfig(storage_path="/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/ray_results_test",
                             name="test_experiment_no_scheduler"),
        param_space=config,
    )
    results = tuner.fit()
    best_result = results.get_best_result("loss", "min")
    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    # train_model only reports "loss". An unconditional metrics["accuracy"]
    # lookup raised KeyError and skipped the test-set evaluation.
    if "accuracy" in best_result.metrics:
        print("Best trial final validation accuracy: {}".format(
            best_result.metrics["accuracy"]))
    test_best_model(best_result, args)
if __name__ == "__main__":
    # Print the name of every visible CUDA device (none on a CPU-only host).
    gpu_names = [torch.cuda.get_device_properties(gpu_index).name
                 for gpu_index in range(torch.cuda.device_count())]
    for gpu_name in gpu_names:
        print(gpu_name)
    # Seed the driver process. NOTE(review): Ray trials presumably run in
    # separate worker processes and may not inherit this seed — confirm.
    torch.manual_seed(2809)
    main(load_args(), gpus_per_trial=4)