"""Hyper-parameter search for the contrastive duo classifier using Ray Tune.

The module exposes three entry points:

* ``train_model`` - the Ray Tune trainable (one trial = one training run).
* ``test_model``  - re-evaluates the best trial's checkpoint.
* ``main``        - builds the search space, runs the tuner and reports.
"""
import os
import tempfile

from config import load_args_contrastive
from dataset_ref import load_data_duo
import torch
import torch.nn as nn
from model import Classification_model_duo_contrastive
import torch.optim as optim


#ray
from ray.air import RunConfig
from ray.tune.search.optuna import OptunaSearch
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler

# File name used for the (model_state, optimizer_state) tuple in a checkpoint.
CHECKPOINT_FILENAME = "checkpoint.pt"

# Parent directory for the per-epoch temporary checkpoint directories.
# The original code passed this path WITHOUT the leading slash, which made it
# relative to the current working directory, whereas the Ray storage path in
# ``main`` is the absolute "/lustre/..." path.
CHECKPOINT_TMP_DIR = (
    "/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/checkpoints"
)


def _setup_device(model):
    """Cast *model* to float64, move it to the compute device and wrap it.

    Returns:
        ``(net, device)`` where ``net`` is the module to call for forward
        passes (a ``DataParallel`` wrapper when several GPUs are visible,
        otherwise ``model`` itself) and ``device`` is the ``torch.device``
        the inputs must be moved to.
    """
    model.double()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    net = model
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        # The wrapper shares its parameters with ``model``, so the optimizer
        # and the checkpoints keep working on the unwrapped module.
        net = nn.DataParallel(model)
    return net, device


def _build_optimizer(name, params, lr):
    """Return the optimizer called *name* ('Adam' or 'SGD') over *params*.

    Raises:
        ValueError: if *name* is not a supported optimizer.
    """
    if name == 'Adam':
        return optim.Adam(params, lr=lr)
    if name == 'SGD':
        return optim.SGD(params, lr=lr, momentum=0.9)
    raise ValueError("Unsupported optimizer: {!r}".format(name))


def _evaluate(net, loader, loss_function, device):
    """Run *net* over *loader* without gradients and return the metrics.

    Each batch is transposed on its first two axes before the forward pass,
    and the label is squeezed, exactly as the training script always did.

    Returns:
        ``(loss, acc, acc_contrastive)``:

        * ``loss`` - summed batch losses divided by
          ``label.shape[0] * len(loader.dataset)`` (label of the last batch).
        * ``acc`` - fraction of batches where the arg-max of the class-0
          logit column equals the arg-min of the label vector.
        * ``acc_contrastive`` - element-wise prediction accuracy, using the
          same denominator as ``loss``.

    Raises:
        ValueError: if *loader* yields no batch.
    """
    net.eval()
    loss_sum = 0.
    n_top1 = 0
    n_elementwise = 0
    n_per_batch = None

    with torch.no_grad():
        for imaer, imana, img_ref, label in loader:
            imaer = imaer.transpose(0, 1).to(device)
            imana = imana.transpose(0, 1).to(device)
            img_ref = img_ref.transpose(0, 1).to(device)
            label = label.transpose(0, 1).squeeze().long().to(device)

            pred_logits = net(imaer, imana, img_ref)

            expected_idx = torch.argmin(label).item()
            predicted_idx = torch.argmax(pred_logits[:, 0]).item()
            n_top1 += int(predicted_idx == expected_idx)
            n_elementwise += (
                torch.argmax(pred_logits, dim=1) == label).sum().item()
            loss_sum += loss_function(pred_logits, label).item()
            n_per_batch = label.shape[0]

    if n_per_batch is None:
        raise ValueError("Evaluation data loader produced no batch.")

    # NOTE(review): CrossEntropyLoss already averages over the batch, so this
    # "loss" is a rescaled quantity. The normalisation is kept unchanged so
    # that values stay comparable with earlier runs.
    n_samples = len(loader.dataset)
    denominator = n_per_batch * n_samples
    return (loss_sum / denominator,
            n_top1 / n_samples,
            n_elementwise / denominator)


def train_model(config, args):
    """Ray Tune trainable: train one configuration and report every epoch.

    Args:
        config: sampled hyper-parameters; reads the keys ``'noise'``,
            ``'positive_prop'``, ``'sampler'``, ``'optimizer'`` and ``'lr'``.
        args: parsed command-line arguments; reads the dataset directories,
            ``batch_size``, ``model`` and ``epoches``.

    Raises:
        ValueError: if ``config['optimizer']`` is not 'Adam' or 'SGD'.
    """
    # load data
    data_train, data_val_batch, _ = load_data_duo(
        base_dir_train=args.dataset_train_dir,
        base_dir_val=args.dataset_val_dir,
        base_dir_test=args.dataset_test_dir,
        batch_size=args.batch_size,
        ref_dir=args.dataset_ref_dir,
        noise_threshold=config['noise'],
        positive_prop=config['positive_prop'],
        sampler=config['sampler'])

    # load model and move parameters to the compute device
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    net, device = _setup_device(model)

    optimizer = _build_optimizer(
        config['optimizer'], model.parameters(), config["lr"])
    loss_function = nn.CrossEntropyLoss()

    # Load existing checkpoint through `get_checkpoint()` API.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, CHECKPOINT_FILENAME),
                map_location=device,
            )
        # Checkpoints are written from the unwrapped model, so they must be
        # restored into it as well (the old code used ``net``, which was
        # undefined on single-device runs).
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # tempfile.TemporaryDirectory(dir=...) needs an existing parent.
    os.makedirs(CHECKPOINT_TMP_DIR, exist_ok=True)

    # Every parameter is trained, as before; evaluation now runs under
    # torch.no_grad(), so the flags no longer need toggling each epoch.
    for param in model.parameters():
        param.requires_grad = True

    # train model
    for _ in range(args.epoches):
        net.train()
        losses = 0.
        acc = 0.

        for imaer, imana, img_ref, label in data_train:
            imaer = imaer.to(device)
            imana = imana.to(device)
            img_ref = img_ref.to(device)
            label = label.long().to(device)

            pred_logits = net(imaer, imana, img_ref)
            loss = loss_function(pred_logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses += loss.item()
            acc += (torch.argmax(pred_logits, dim=1) == label).sum().item()

        # NOTE(review): the summed batch-mean losses are divided by the
        # number of samples; kept as-is for comparability with older runs.
        losses_train = losses / len(data_train.dataset)
        acc_train = acc / len(data_train.dataset)

        losses_val, acc_val, acc_contrastive_val = _evaluate(
            net, data_val_batch, loss_function, device)

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and can later be retrieved through ``get_checkpoint()``.
        # A checkpoint must be a directory, and the report has to happen
        # while the temporary directory still exists.
        with tempfile.TemporaryDirectory(
                dir=CHECKPOINT_TMP_DIR) as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, CHECKPOINT_FILENAME)
            torch.save((model.state_dict(), optimizer.state_dict()), path)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report(
                {
                    "train loss": losses_train,
                    "train contrastive acc": acc_train,
                    "val loss": losses_val,
                    "val acc": acc_val,
                    "val contrastive acc": acc_contrastive_val,
                },
                checkpoint=checkpoint,
            )
    print("Finished Training")


def test_model(best_result, args):
    """Reload the best trial's checkpoint and print its evaluation metrics.

    The metrics are computed on the validation loader returned by
    ``load_data_duo`` (the second element of its result), built with the
    best trial's data hyper-parameters.

    Args:
        best_result: Ray Tune result exposing ``config`` and ``checkpoint``.
        args: parsed command-line arguments (dataset directories,
            ``batch_size`` and ``model``).
    """
    # load data
    _, data_val_batch, _ = load_data_duo(
        base_dir_train=args.dataset_train_dir,
        base_dir_val=args.dataset_val_dir,
        base_dir_test=args.dataset_test_dir,
        batch_size=args.batch_size,
        ref_dir=args.dataset_ref_dir,
        noise_threshold=best_result.config['noise'],
        positive_prop=best_result.config['positive_prop'],
        sampler=best_result.config['sampler'])

    # load model; cast to float64 *before* loading so no precision is lost
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    net, device = _setup_device(model)

    # load weight; map_location lets a GPU-written checkpoint load on CPU
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        model_state, _ = torch.load(
            os.path.join(checkpoint_dir, CHECKPOINT_FILENAME),
            map_location=device,
        )
    model.load_state_dict(model_state)

    loss_function = nn.CrossEntropyLoss()
    losses, acc, acc_contrastive = _evaluate(
        net, data_val_batch, loss_function, device)
    print("Best trial test set AsyncHyperBandSchedulerloss: loss {} acc {} acc_contrastive {}".format(losses,acc,acc_contrastive))


def main(args, gpus_per_trial=1):
    """Run the Ray Tune search, print the best trial and re-evaluate it.

    Args:
        args: parsed command-line arguments forwarded to ``train_model``.
        gpus_per_trial: number of GPUs reserved for each trial.
    """
    config = {
        "lr": tune.loguniform(1e-4, 1e-2),
        # A log-uniform distribution cannot have 0 as its lower bound, so
        # the original ``tune.loguniform(0, 500)`` was invalid. A uniform
        # distribution keeps the intended [0, 500] range.
        "noise": tune.uniform(0, 500),
        "positive_prop": tune.uniform(0, 100),
        "optimizer": tune.choice(['Adam', 'SGD']),
        "sampler": tune.choice(['random', 'balanced']),
    }
    scheduler = ASHAScheduler(
        max_t=100,
        grace_period=20,
        reduction_factor=3,
        brackets=1,
    )
    algo = OptunaSearch()

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": 80, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23.5,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=50,
            metric="val loss",
            mode='min',
        ),
        run_config=RunConfig(
            storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
            name="test_experiment_no_scheduler"
        ),
        param_space=config
    )
    results = tuner.fit()

    best_result = results.get_best_result("val loss", "min")

    print("Best trial config: {}".format(best_result.config))
    # The metric keys must match the ones passed to ``train.report``.
    print("Best trial final validation loss: {}".format(
        best_result.metrics["val loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["val acc"]))

    test_model(best_result, args)


if __name__ == '__main__':
    args = load_args_contrastive()
    print(args)
    main(args)