From f59dbc73a94c16fe8cb64e8871754f87257fa743 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 09:59:45 +0200
Subject: [PATCH] fix : weight Xent loss gpu add : ray tune for base model
 (for leo C)

---
 dataset/dataset.py    |  17 ++-
 image_ref/main_ray.py |   6 +-
 main_ray.py           | 243 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 258 insertions(+), 8 deletions(-)
 create mode 100644 main_ray.py

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 054b697..9f84e63 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -201,11 +201,18 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
 
     train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
     val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
-    generator1 = torch.Generator().manual_seed(42)
-    indices = torch.randperm(len(train_dataset), generator=generator1)
-    val_size = len(train_dataset) // 5
-    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
-    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+    classes_numbers = [51, 12, 9, 10, 86, 231, 20, 13, 24, 96, 11, 39, 11]
+    classes_names = ["Citrobacter freundii", "Citrobacter koseri", "Enterobacter asburiae", "Enterobacter cloacae",
+                     "Enterobacter hormaechei", "Escherichia coli", "Klebsiella aerogenes", "Klebsiella michiganensis",
+                     "Klebsiella oxytoca", "Klebsiella pneumoniae", "Klebsiella quasipneumoniae", "Proteus mirabilis",
+                     "Salmonella enterica"]
+    repart_weights = [classes_names[i] for i in range(len(classes_names)) for k in
+                      range(classes_numbers[i])]
+    iTrain, iVal = train_test_split(range(len(train_dataset)), test_size=0.2, shuffle=shuffle, stratify=repart_weights,
+                                    random_state=42)
+    train_dataset = torch.utils.data.Subset(train_dataset, iTrain)
+    val_dataset = torch.utils.data.Subset(val_dataset, iVal)
+
 
     data_loader_train = data.DataLoader(
         dataset=train_dataset,
diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 9f33926..03d9016 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -43,9 +43,7 @@ def train_model(config,args):
     if torch.cuda.is_available():
         device = "cuda:0"
         if torch.cuda.device_count() > 1:
-            print(type(model))
             net = torch.nn.DataParallel(model)
-            print(type(net))
     model.to(device)
 
     if config['optimizer']=='Adam' :
@@ -54,7 +52,9 @@
         optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
 
     # init training
     n_class = len(data_train.dataset.classes)
-    loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class]))
+    weight = torch.Tensor([1/n_class,1-1/n_class])
+    weight = weight.to(device)
+    loss_function = nn.CrossEntropyLoss(weight=weight)
     # Load existing checkpoint through `get_checkpoint()` API.
     if train.get_checkpoint():
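A note on the weighted-loss hunk above: `Tensor.to(device)` is not an in-place operation, so the result has to be rebound (`weight = weight.to(device)`) or the tensor created on the target device directly; otherwise `nn.CrossEntropyLoss` keeps a CPU weight tensor and fails once logits and labels are on CUDA. A minimal sketch of the working pattern, with an illustrative `n_class`:

import torch
import torch.nn as nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"

n_class = 2
# build the weight tensor directly on the target device (to() is out-of-place)
weight = torch.tensor([1 / n_class, 1 - 1 / n_class], device=device)
loss_function = nn.CrossEntropyLoss(weight=weight)

# logits and targets must live on the same device as the weights
logits = torch.randn(4, n_class, device=device)
target = torch.randint(0, n_class, (4,), device=device)
loss = loss_function(logits, target)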
diff --git a/main_ray.py b/main_ray.py
new file mode 100644
index 0000000..86cf087
--- /dev/null
+++ b/main_ray.py
@@ -0,0 +1,243 @@
+import os
+import tempfile
+
+from config.config import load_args
+from dataset.dataset import load_data_duo
+import torch
+import torch.nn as nn
+from models.model import Classification_model_duo
+import torch.optim as optim
+
+
+#ray
+from ray.air import RunConfig
+from ray.tune.search.optuna import OptunaSearch
+from ray import train, tune
+from ray.train import Checkpoint
+from ray.tune.schedulers import ASHAScheduler
+
+
+def train_model(config, args):
+    # load data
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir,
+                                          batch_size=args.batch_size,
+                                          noise_threshold=config['noise'],
+                                          )
+
+    # load model
+    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.classes))
+
+    # move parameters to GPU
+    model.double()
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            net = torch.nn.DataParallel(model)
+    model.to(device)
+
+    optimizer = optim.Adam(model.parameters(), lr=config["lr"], betas=(config['beta1'], config['beta2']))
+
+    # init training
+    if config['loss'] == 'base':
+        loss_function = nn.CrossEntropyLoss()
+    elif config['loss'] == 'weighed':
+        classes_numbers = [51, 12, 9, 10, 86, 231, 20, 13, 24, 96, 11, 39, 11]
+        loss_weights = torch.tensor([1 / n for n in classes_numbers])
+        # to() is out-of-place: rebind the result so the weights actually land on the GPU
+        loss_weights = loss_weights.to(device)
+        loss_function = nn.CrossEntropyLoss(weight=loss_weights)
+    # Load existing checkpoint through `get_checkpoint()` API.
+    if train.get_checkpoint():
+        loaded_checkpoint = train.get_checkpoint()
+        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
+            model_state, optimizer_state = torch.load(
+                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
+            )
+            model.load_state_dict(model_state)
+            optimizer.load_state_dict(optimizer_state)
+
+    # train model
+    for e in range(args.epoches):
+
+        # train loss
+        model.train()
+        losses = 0.
+        acc = 0.
+        for param in model.parameters():
+            param.requires_grad = True
+
+        for imaer, imana, label in data_train:
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                label = label.cuda()
+            pred_logits = model.forward(imaer, imana)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        losses_train = losses / len(data_train.dataset)
+        acc_train = acc / len(data_train.dataset)
+
+        # validation loss
+        model.eval()
+        losses = 0.
+        acc = 0.
+        for param in model.parameters():
+            param.requires_grad = False
+
+        for imaer, imana, label in data_test:
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                label = label.cuda()
+            pred_logits = model.forward(imaer, imana)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+        losses_val = losses / (label.shape[0] * len(data_test.dataset))
+        acc_val = acc / (len(data_test.dataset))
+
+        # Here we save a checkpoint. It is automatically registered with
+        # Ray Tune and will potentially be accessed through ``get_checkpoint()``
+        # in future iterations.
+        # Note: to save a file-based checkpoint, you still need to put it under
+        # a directory to construct a Checkpoint object.
+        with tempfile.TemporaryDirectory(
+                dir='/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/checkpoints') as temp_checkpoint_dir:
+            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
+
+            torch.save(
+                (model.state_dict(), optimizer.state_dict()), path
+            )
+            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+            print(checkpoint.path)
+            train.report(
+                {"train loss": losses_train, "train acc": acc_train, "val loss": losses_val, "val acc": acc_val},
+                checkpoint=checkpoint,
+            )
+    print("Finished Training")
+
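The checkpointing in `train_model` is Ray Tune's standard round trip: write state into a directory, wrap it with `Checkpoint.from_directory`, report it alongside the metrics, and reload it through `train.get_checkpoint()` when a trial is resumed. A stripped-down sketch of that pattern (the helper names are illustrative, not from the patch):

import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

def report_with_checkpoint(model, optimizer, metrics):
    # Save state under a directory, wrap it as a Checkpoint, and report it
    # together with the metrics for this training iteration.
    with tempfile.TemporaryDirectory() as tmp_dir:
        torch.save((model.state_dict(), optimizer.state_dict()),
                   os.path.join(tmp_dir, "checkpoint.pt"))
        train.report(metrics, checkpoint=Checkpoint.from_directory(tmp_dir))

def maybe_restore(model, optimizer):
    # On restart, Ray hands back the last reported checkpoint (or None).
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(ckpt_dir, "checkpoint.pt"))
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)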
+
+def test_model(best_result, args):
+
+    # 'res_count_thr' is only part of the image_ref search space; default to 'none' here
+    res_count_thr = best_result.config.get('res_count_thr', 'none')
+    if res_count_thr == 'none':
+        ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref'
+    elif res_count_thr == '10':
+        ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_10'
+    else:
+        ref_dir = '/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/img_ref_count_th_5'
+
+    # load data
+    _, data_test = load_data_duo(base_dir=args.dataset_dir,
+                                 batch_size=args.batch_size,
+                                 noise_threshold=best_result.config['noise'])
+
+    # load model
+    model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.classes))
+    model.double()
+    # load weights
+    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
+
+    model_state, optimizer_state = torch.load(checkpoint_path)
+    model.load_state_dict(model_state)
+
+    # move parameters to GPU
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            net = torch.nn.DataParallel(model)
+    model.to(device)
+    # init evaluation
+    loss_function = nn.CrossEntropyLoss()
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model.forward(imaer, imana)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    losses_val = losses / (label.shape[0] * len(data_test.dataset))
+    acc_val = acc / (len(data_test.dataset))
+    print("Best trial test set: loss {} acc {}".format(losses_val, acc_val))
+
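One caveat on the validation metrics in both functions above: `losses` accumulates per-batch mean losses, so dividing by `label.shape[0] * len(data_test.dataset)` scales the sum by whatever the last batch size happened to be, rather than producing a per-sample average (the training loop divides by `len(data_train.dataset)` alone, so the train and val numbers are not on the same scale). Since every trial uses the same loaders, the ranking across trials is unaffected, but if a true per-sample mean loss is wanted, one possible computation (a sketch; device transfers omitted for brevity):

# per-sample mean: weight each batch's mean loss by its batch size,
# then divide once by the dataset size
losses = 0.
acc = 0.
for imaer, imana, label in data_test:
    label = label.long()
    pred_logits = model(imaer, imana)
    pred_class = torch.argmax(pred_logits, dim=1)
    acc += (pred_class == label).sum().item()
    losses += loss_function(pred_logits, label).item() * label.shape[0]
losses_val = losses / len(data_test.dataset)
acc_val = acc / len(data_test.dataset)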
+
+def main(args, gpus_per_trial=1):
+
+    ''' HP search space:
+    lr: vary by a factor of 10 in both directions, i.e. from 1e-4 to 1e-2
+    beta1: from 0.8 to 0.99
+    beta2: from 0.99 to 0.9999
+    noise: from 0 to 10k
+    try with a weighted and an unweighted loss
+    '''
+    config = {
+        "lr": tune.loguniform(1e-4, 1e-2),
+        "noise": tune.loguniform(1, 10000),
+        "beta1": tune.uniform(0.8, 0.99),
+        "beta2": tune.uniform(0.99, 0.9999),
+        "loss": tune.choice(['base', 'weighed']),
+    }
+    scheduler = ASHAScheduler(
+        max_t=100,
+        grace_period=3,
+        reduction_factor=3,
+        brackets=1,
+    )
+    algo = OptunaSearch()
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_model, args=args),
+            resources={"cpu": 20, "gpu": gpus_per_trial}
+        ),
+        tune_config=tune.TuneConfig(
+            time_budget_s=3600 * 19.5,
+            search_alg=algo,
+            scheduler=scheduler,
+            num_samples=-1,
+            metric="val loss",
+            mode='min',
+        ),
+        run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/ray_results_test",
+                             name="base_experiment"
+                             ),
+        param_space=config
+    )
+    results = tuner.fit()
+
+    best_result = results.get_best_result("val loss", "min")
+
+    print("Best trial config: {}".format(best_result.config))
+    print("Best trial final validation loss: {}".format(
+        best_result.metrics["val loss"]))
+    print("Best trial final validation accuracy: {}".format(
+        best_result.metrics["val acc"]))
+
+    test_model(best_result, args)
+
+
+if __name__ == '__main__':
+    args = load_args()
+    print(args)
+    main(args)
--
GitLab
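A closing note on the `dataset/dataset.py` hunk at the top of the patch: the stratified split relies on scikit-learn's `train_test_split` (which that module must import) and on `repart_weights` holding one class name per image in dataset order, which assumes `ImageFolderDuo` enumerates samples class by class with exactly the hard-coded counts. A self-contained sketch of the idea, with made-up class names and counts:

from collections import Counter
from sklearn.model_selection import train_test_split

# one label per sample, in dataset order (names and counts are hypothetical)
classes_numbers = [50, 10, 20]
classes_names = ["class_a", "class_b", "class_c"]
labels = [name for name, n in zip(classes_names, classes_numbers) for _ in range(n)]

# stratify keeps the class proportions (approximately) equal in both splits
iTrain, iVal = train_test_split(range(len(labels)), test_size=0.2,
                                shuffle=True, stratify=labels, random_state=42)
print(Counter(labels[i] for i in iTrain))  # ~80% of each class
print(Counter(labels[i] for i in iVal))    # ~20% of each class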