From f59dbc73a94c16fe8cb64e8871754f87257fa743 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 18 Apr 2025 09:59:45 +0200
Subject: [PATCH] fix: move the weighted Xent loss to the GPU; add: Ray Tune
 for the base model (for Leo C)

---
 dataset/dataset.py    |  17 ++-
 image_ref/main_ray.py |   6 +-
 main_ray.py           | 230 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 245 insertions(+), 8 deletions(-)
 create mode 100644 main_ray.py

diff --git a/dataset/dataset.py b/dataset/dataset.py
index 054b697..9f84e63 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -201,11 +201,18 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
 
     train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
     val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
-    generator1 = torch.Generator().manual_seed(42)
-    indices = torch.randperm(len(train_dataset), generator=generator1)
-    val_size = len(train_dataset) // 5
-    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
-    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+    classes_numbers = [51, 12, 9, 10, 86, 231, 20, 13, 24, 96, 11, 39, 11]
+    classes_names = ["Citrobacter freundii", "Citrobacter koseri", "Enterobacter asburiae", "Enterobacter cloacae",
+     "Enterobacter hormaechei", "Escherichia coli", "Klebsiella aerogenes", "Klebsiella michiganensis",
+     "Klebsiella oxytoca", "Klebsiella pneumoniae", "Klebsiella quasipneumoniae", "Proteus mirabilis",
+     "Salmonella enterica"]
+    repart_weights = [classes_names[i] for i in range(len(classes_names)) for k in
+                      range(classes_numbers[i])]
+    iTrain, iVal = train_test_split(range(len(train_dataset)), test_size=0.2, shuffle=shuffle, stratify=repart_weights,
+                                    random_state=42)
+    train_dataset = torch.utils.data.Subset(train_dataset, iTrain)
+    val_dataset = torch.utils.data.Subset(val_dataset, iVal)
+
 
     data_loader_train = data.DataLoader(
         dataset=train_dataset,
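Note that this hunk calls sklearn's `train_test_split` but does not add the corresponding import, so dataset.py also needs `from sklearn.model_selection import train_test_split` at the top. The hard-coded `classes_numbers` list must additionally be kept in sync with the on-disk class counts by hand. A sketch of a variant that derives the stratification labels from the dataset itself, assuming `ImageFolderDuo` exposes a per-sample `targets` attribute the way torchvision's `ImageFolder` does (an assumption, not verified against this repo):

```python
# Hypothetical sketch: build stratified train/val Subsets without hard-coding
# per-class counts. Assumes the dataset has a `targets` attribute holding one
# class index per sample, like torchvision's ImageFolder.
import torch
from sklearn.model_selection import train_test_split

def stratified_subsets(dataset, test_size=0.2, seed=42):
    labels = dataset.targets  # assumed attribute: per-sample class indices
    i_train, i_val = train_test_split(range(len(dataset)),
                                      test_size=test_size,
                                      stratify=labels,
                                      random_state=seed)
    return (torch.utils.data.Subset(dataset, i_train),
            torch.utils.data.Subset(dataset, i_val))
```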
diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
index 9f33926..03d9016 100644
--- a/image_ref/main_ray.py
+++ b/image_ref/main_ray.py
@@ -43,9 +43,7 @@ def train_model(config,args):
     if torch.cuda.is_available():
         device = "cuda:0"
         if torch.cuda.device_count() > 1:
-            print(type(model))
             net = torch.nn.DataParallel(model)
-            print(type(net))
     model.to(device)
 
     if config['optimizer']=='Adam' :
@@ -54,7 +52,9 @@ def train_model(config,args):
         optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
     # init training
     n_class = len(data_train.dataset.classes)
-    loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class]))
+    weight = torch.Tensor([1/n_class, 1-1/n_class])
+    weight = weight.to(device)  # Tensor.to() returns a new tensor; reassign to actually move it
+    loss_function = nn.CrossEntropyLoss(weight=weight)
 
     # Load existing checkpoint through `get_checkpoint()` API.
     if train.get_checkpoint():
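A detail worth underlining in the hunk above: `Tensor.to()` is out-of-place, so a bare `weight.to(device)` is a silent no-op, and the CPU tensor would still be captured by `nn.CrossEntropyLoss`, failing with a device mismatch once the logits arrive on CUDA. A minimal sketch of the pitfall (runs on CPU-only machines as well):

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"

w = torch.tensor([0.1, 0.9])
w.to(device)          # returns a new tensor that is immediately discarded
print(w.device)       # still cpu

w = w.to(device)      # reassignment is what actually moves the tensor
print(w.device)       # cuda:0 when a GPU is available
```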
diff --git a/main_ray.py b/main_ray.py
new file mode 100644
index 0000000..86cf087
--- /dev/null
+++ b/main_ray.py
@@ -0,0 +1,230 @@
+import os
+import tempfile
+
+from config.config import load_args
+from dataset.dataset import load_data_duo
+import torch
+import torch.nn as nn
+from models.model import Classification_model_duo
+import torch.optim as optim
+
+
+#ray
+from ray.air import RunConfig
+from ray.tune.search.optuna import OptunaSearch
+from ray import train, tune
+from ray.train import Checkpoint
+from ray.tune.schedulers import ASHAScheduler
+
+def train_model(config, args):
+    # load data
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir,
+                                          batch_size=args.batch_size,
+                                          noise_threshold=config['noise'],
+                                          )
+
+    # load model (the loaders wrap Subsets, so .classes lives on the underlying ImageFolderDuo)
+    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes))
+
+    # move parameters to GPU
+    model.double()
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            model = torch.nn.DataParallel(model)  # reassign: DataParallel wraps the model rather than modifying it in place
+    model.to(device)
+
+    optimizer = optim.Adam(model.parameters(), lr=config["lr"],betas=(config['beta1'],config['beta2']))
+
+    # init training
+    if config['loss']=='base':
+        loss_function = nn.CrossEntropyLoss()
+    elif config['loss'] == 'weighed':
+        classes_numbers = [51, 12, 9, 10, 86, 231, 20, 13, 24, 96, 11, 39, 11]
+        # inverse-frequency weights; .double() matches the model's dtype and
+        # .to() is not in-place, so the result must be reassigned
+        loss_weights = torch.tensor([1/n for n in classes_numbers]).double().to(device)
+        loss_function = nn.CrossEntropyLoss(weight=loss_weights)
+    # Load existing checkpoint through `get_checkpoint()` API.
+    if train.get_checkpoint():
+        loaded_checkpoint = train.get_checkpoint()
+        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
+            model_state, optimizer_state = torch.load(
+                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
+            )
+            model.load_state_dict(model_state)
+            optimizer.load_state_dict(optimizer_state)
+
+    # train model
+    for e in range(args.epoches):
+
+        #train loss
+        model.train()
+        losses = 0.
+        acc = 0.
+        for param in model.parameters():
+            param.requires_grad = True
+
+        for imaer, imana, label in data_train:
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                label = label.cuda()
+            pred_logits = model.forward(imaer, imana)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        losses_train = losses / len(data_train.dataset)
+        acc_train = acc / len(data_train.dataset)
+
+        #validation loss
+        model.eval()
+        losses = 0.
+        acc = 0.
+        for param in model.parameters():
+            param.requires_grad = False
+
+        for imaer, imana, label in data_test:
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                label = label.cuda()
+            pred_logits = model.forward(imaer, imana)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+        losses_val = losses / len(data_test.dataset)  # match the train-loss normalization
+        acc_val = acc / (len(data_test.dataset))
+
+        # Here we save a checkpoint. It is automatically registered with
+        # Ray Tune and can be accessed through ``get_checkpoint()`` in
+        # future iterations.
+        # Note: to save a file-like checkpoint, you still need to place it
+        # under a directory to construct a Checkpoint.
+        with tempfile.TemporaryDirectory(
+                dir='/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/checkpoints') as temp_checkpoint_dir:
+            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
+
+            torch.save(
+                (model.state_dict(), optimizer.state_dict()), path
+            )
+            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+            print(checkpoint.path)
+            train.report(
+                {"train loss": losses_train, "train acc": acc_train, "val loss": losses_val, "val acc": acc_val},
+                checkpoint=checkpoint)
+    print("Finished Training")
+
+
+def test_model(best_result, args):
+
+    # load data
+    _, data_test = load_data_duo(base_dir=args.dataset_dir,
+                                 batch_size=args.batch_size,
+                                 noise_threshold=best_result.config['noise'])
+
+    # load model (the loader wraps a Subset, so .classes lives on the underlying ImageFolderDuo)
+    model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.dataset.classes))
+    model.double()
+    # load weight
+    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
+
+    model_state, optimizer_state = torch.load(checkpoint_path)
+    model.load_state_dict(model_state)
+
+    # move parameters to GPU
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            model = torch.nn.DataParallel(model)
+    model.to(device)
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model.forward(imaer, imana)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    losses_val = losses / len(data_test.dataset)  # match the normalization used during training
+    acc_val = acc / (len(data_test.dataset))
+    print("Best trial test set AsyncHyperBandSchedulerloss: loss {}   acc {} ".format(losses_val,acc_val))
+
+def main(args, gpus_per_trial=1):
+
+    ''' HP search space:
+    lr: vary by a factor of 10 in each direction, i.e. from 1e-4 to 1e-2
+    beta1: from 0.8 to 0.99
+    beta2: from 0.99 to 0.9999
+    noise: from 0 to 10k
+    try the loss both weighted and unweighted
+    '''
+    config = {
+        "lr": tune.loguniform(1e-4, 1e-2),
+        "noise": tune.loguniform(1, 10000),
+        "beta1": tune.uniform(0.8, 0.99),
+        "beta2": tune.uniform(0.99, 0.9999),
+        "loss": tune.choice(['base', 'weighed']),
+    }
+    scheduler = ASHAScheduler(
+        max_t=100,
+        grace_period=3,
+        reduction_factor=3,
+        brackets=1,
+    )
+    algo = OptunaSearch()
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_model, args=args),
+            resources={"cpu": 20, "gpu": gpus_per_trial}
+        ),
+        tune_config=tune.TuneConfig(
+            time_budget_s=3600 * 19.5,
+            search_alg=algo,
+            scheduler=scheduler,
+            num_samples=-1,
+            metric="val loss",
+            mode='min',
+        ),
+        run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/ray_results_test",
+                             name="base_experiment"
+                             ),
+        param_space=config
+    )
+    results = tuner.fit()
+
+    best_result = results.get_best_result("val loss", "min")
+
+    print("Best trial config: {}".format(best_result.config))
+    print("Best trial final validation loss: {}".format(
+        best_result.metrics["loss"]))
+    print("Best trial final validation accuracy: {}".format(
+        best_result.metrics["accuracy"]))
+
+    test_model(best_result, args)
+
+if __name__ == '__main__':
+    args = load_args()
+    print(args)
+    main(args)
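One contract the script above has to honor throughout: the metric key reported via `train.report` must match the key passed to `TuneConfig`, `get_best_result`, and `best_result.metrics` exactly, including the space in "val loss". A minimal self-contained sketch of that contract, using a toy objective with made-up numbers (not this repo's model):

```python
# Minimal sketch of the Ray Tune metric-key contract: the string reported in
# train.report must match TuneConfig(metric=...) and get_best_result(...).
from ray import train, tune

def objective(config):
    for step in range(5):
        # toy pseudo-loss; any dict key works, but it must be used consistently
        train.report({"val loss": (config["lr"] - 1e-3) ** 2 + 1.0 / (step + 1)})

tuner = tune.Tuner(
    objective,
    tune_config=tune.TuneConfig(metric="val loss", mode="min", num_samples=4),
    param_space={"lr": tune.loguniform(1e-4, 1e-2)},
)
results = tuner.fit()
best = results.get_best_result("val loss", "min")
print(best.config, best.metrics["val loss"])  # same key, including the space
```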
-- 
GitLab