diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py
new file mode 100644
index 0000000000000000000000000000000000000000..c66087cb3ab2c67e121054838556b0d57253129b
--- /dev/null
+++ b/image_ref/main_ray.py
@@ -0,0 +1,263 @@
+import os
+import tempfile
+
+from config import load_args_contrastive
+from dataset_ref import load_data_duo
+import torch
+import torch.nn as nn
+from model import Classification_model_duo_contrastive
+import torch.optim as optim
+
+
+#ray
+from ray.air import RunConfig
+from ray.tune.search.optuna import OptunaSearch
+from ray import train, tune
+from ray.train import Checkpoint
+from ray.tune.schedulers import ASHAScheduler
+
+def train_model(config,args):
+    # load data
+    data_train, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                                  base_dir_val=args.dataset_val_dir,
+                                                  base_dir_test=args.dataset_test_dir,
+                                                  batch_size=args.batch_size,
+                                                  ref_dir=args.dataset_ref_dir,
+                                                  noise_threshold=config['noise'],
+                                                  positive_prop=config['positive_prop'], sampler=config['sampler'])
+
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+
+    # move parameters to GPU
+    model.double()
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            print(type(model))
+            net = torch.nn.DataParallel(model)
+            print(type(net))
+    model.to(device)
+
+    if config['optimizer']=='Adam' :
+        optimizer = optim.Adam(model.parameters(), lr=config["lr"])
+    elif config['optimizer']=='SGD' :
+        optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+
+    # Load existing checkpoint through `get_checkpoint()` API.
+    if train.get_checkpoint():
+        loaded_checkpoint = train.get_checkpoint()
+        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
+            model_state, optimizer_state = torch.load(
+                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
+            )
+            net.load_state_dict(model_state)
+            optimizer.load_state_dict(optimizer_state)
+
+    # train model
+    for e in range(args.epoches):
+
+        #train loss
+        model.train()
+        losses = 0.
+        acc = 0.
+        for param in model.parameters():
+            param.requires_grad = True
+
+        for imaer, imana, img_ref, label in data_train:
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                img_ref = img_ref.cuda()
+                label = label.cuda()
+            if torch.cuda.device_count() > 1:
+                pred_logits = model.module.forward(imaer, imana, img_ref)
+            else:
+                pred_logits = model.forward(imaer, imana, img_ref)
+
+
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        losses_train = losses / len(data_train.dataset)
+        acc_train = acc / len(data_train.dataset)
+
+        #validation loss
+        model.eval()
+        losses = 0.
+        acc = 0.
+        acc_contrastive = 0.
+        for param in model.parameters():
+            param.requires_grad = False
+
+        for imaer, imana, img_ref, label in data_val_batch:
+            imaer = imaer.transpose(0, 1)
+            imana = imana.transpose(0, 1)
+            img_ref = img_ref.transpose(0, 1)
+            label = label.transpose(0, 1)
+            label = label.squeeze()
+            label = label.long()
+            if torch.cuda.is_available():
+                imaer = imaer.cuda()
+                imana = imana.cuda()
+                img_ref = img_ref.cuda()
+                label = label.cuda()
+            label_class = torch.argmin(label).data.cpu().numpy()
+
+            if torch.cuda.device_count() > 1:
+                pred_logits = model.module.forward(imaer, imana, img_ref)
+            else:
+                pred_logits = model.forward(imaer, imana, img_ref)
+
+            pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+            acc_contrastive += (
+                    torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
+            acc += (pred_class == label_class)
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+        losses_val = losses / (label.shape[0] * len(data_val_batch.dataset))
+        acc_val = acc / (len(data_val_batch.dataset))
+        acc_contrastive_val = acc_contrastive / (label.shape[0] * len(data_val_batch.dataset))
+
+        # Here we save a checkpoint. It is automatically registered with
+        # Ray Tune and will potentially be accessed through in ``get_checkpoint()``
+        # in future iterations.
+        # Note to save a file like checkpoint, you still need to put it under a directory
+        # to construct a checkpoint.
+        with tempfile.TemporaryDirectory(
+                dir='lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/checkpoints') as temp_checkpoint_dir:
+            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
+
+            torch.save(
+                (model.state_dict(), optimizer.state_dict()), path
+            )
+            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+            print(checkpoint.path)
+            train.report(
+                {"train loss": losses_train, "train contrastive acc": acc_train,"val loss": losses_val,"val acc": acc_val,"val contrastive acc": acc_contrastive_val,},
+                checkpoint=checkpoint,)
+    print("Finished Training")
+
+
+def test_model(best_result, args):
+    # load data
+    _, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                         base_dir_val=args.dataset_val_dir,
+                                         base_dir_test=args.dataset_test_dir,
+                                         batch_size=args.batch_size,
+                                         ref_dir=args.dataset_ref_dir,
+                                         noise_threshold=best_result.config['noise'],
+                                         positive_prop=best_result.config['positive_prop'], sampler=best_result.config['sampler'])
+
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+    model.double()
+    # load weight
+    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
+
+    model_state, optimizer_state = torch.load(checkpoint_path)
+    model.load_state_dict(model_state)
+
+    # move parameters to GPU
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        if torch.cuda.device_count() > 1:
+            print(type(model))
+            net = torch.nn.DataParallel(model)
+            print(type(net))
+    model.to(device)
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    model.eval()
+    losses = 0.
+    acc = 0.
+    acc_contrastive = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, img_ref, label in data_val_batch:
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
+        label = label.squeeze()
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        label_class = torch.argmin(label).data.cpu().numpy()
+        pred_logits = model.forward(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+        acc_contrastive += (
+                torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
+        acc += (pred_class == label_class)
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    losses = losses / (label.shape[0] * len(data_val_batch.dataset))
+    acc = acc / (len(data_val_batch.dataset))
+    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_val_batch.dataset))
+    print("Best trial test set AsyncHyperBandSchedulerloss: loss {}   acc {}   acc_contrastive {}".format(losses,acc,acc_contrastive))
+
+def main(args, gpus_per_trial=1):
+    config = {
+        "lr": tune.loguniform(1e-4, 1e-2),
+        "noise": tune.loguniform(0, 500),
+        "positive_prop": tune.uniform(0, 100),
+        "optimizer": tune.choice(['Adam', 'SGD']),
+        "sampler": tune.choice(['random', 'balanced']),
+    }
+    scheduler = ASHAScheduler(
+        max_t=100,
+        grace_period=20,
+        reduction_factor=3,
+        brackets=1,
+    )
+    algo = OptunaSearch()
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_model, args=args),
+            resources={"cpu": 80, "gpu": gpus_per_trial}
+        ),
+        tune_config=tune.TuneConfig(
+            time_budget_s=3600 * 23.5,
+            search_alg=algo,
+            scheduler=scheduler,
+            num_samples=50,
+            metric="val loss",
+            mode='min',
+
+        ),
+        run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
+                             name="test_experiment_no_scheduler"
+                             ),
+        param_space=config
+
+    )
+    results = tuner.fit()
+
+    best_result = results.get_best_result("val loss", "min")
+
+    print("Best trial config: {}".format(best_result.config))
+    print("Best trial final validation loss: {}".format(
+        best_result.metrics["loss"]))
+    print("Best trial final validation accuracy: {}".format(
+        best_result.metrics["accuracy"]))
+
+    test_model(best_result, args)
+
+if __name__ == '__main__':
+    args = load_args_contrastive()
+    print(args)
+    main(args)