diff --git a/image_ref/main_sweep.py b/image_ref/main_sweep.py
new file mode 100644
index 0000000000000000000000000000000000000000..b136e4a83917aee146ccfea85fdb32d3130fe8fa
--- /dev/null
+++ b/image_ref/main_sweep.py
@@ -0,0 +1,31 @@
+import wandb as wdb
+
+
+if __name__ == '__main__':
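+    # random search over the training hyper-parameters, minimising the logged "validation loss"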
+    sweep_configuration = {
+        "program": "sweep_train.py",
+        "method": "random",
+        "metric": {"goal": "minimize", "name": "validation loss"},
+        "parameters": {
+            "epoches":{"value": 50},
+            "eval_inter":{"value": 1},
+            "noise_threshold": {"distribution" : "log_uniform_values", "max": 10000., "min": 0.0001},
+            "lr": {"distribution" : "log_uniform_values", "max": 0.01, "min": 0.0001},
+            "batch_size": {"value": 64},
+            "positive_prop": {"distribution" : "uniform","max": 95., "min": 5.},
+            "opti": {"value": "adam"},
+            "model": {"value": "resnet18"},
+            "sampler": {"values": ["random","balanced"]},
+            "dataset_train_dir": {"value": "data/processed_data_wiff/npy_image/train_data"},
+            "dataset_val_dir": {"value": "data/processed_data_wiff/npy_image/test_data"},
+            "dataset_ref_dir": {"values": ["image_ref/img_ref","image_ref/img_ref_count_th_10","image_ref/img_ref_count_th_5"]},
+        },
+    }
+    sweep_id = wdb.sweep(sweep=sweep_configuration, project="param_sweep_contrastive")
+
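+    # run the sweep controller locally in this process instead of relying on the hosted controller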
+    sweep = wdb.controller(sweep_id)
+    sweep.configure_controller(type="local")
+    sweep.run()
diff --git a/image_ref/sweep_train.py b/image_ref/sweep_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..056df14c20e5ee20512ab4165a15959539e3024a
--- /dev/null
+++ b/image_ref/sweep_train.py
@@ -0,0 +1,146 @@
+import os
+import wandb as wdb
+from dataset_ref import load_data_duo
+import torch
+import torch.nn as nn
+from model import Classification_model_duo_contrastive
+import torch.optim as optim
+
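+# one training epoch: forward/backward over every batch, then log epoch-level metrics to W&B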
+def train_duo(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
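+    # re-enable gradients, since val_duo turns them off between epochs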
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for imaer, imana, img_ref, label in data_train:
+        imaer = imaer.float()
+        imana = imana.float()
+        img_ref = img_ref.float()
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+
+    wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})
+
+    return losses, acc
+
+
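+# evaluation pass over the validation loader; no parameter updates are performed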
+def val_duo(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    acc_contrastive = 0.
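+    # freeze all parameters for evaluation (train_duo re-enables them each epoch)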
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, img_ref, label in data_test:
+        imaer = imaer.float()
+        imana = imana.float()
+        img_ref = img_ref.float()
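+        # swap dims 0 and 1: the val loader packs one group of candidates per item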
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
+        label = label.squeeze()
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
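+        # the true candidate carries label 0; the group prediction is the candidate with the highest class-0 logit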
+        label_class = torch.argmin(label).data.cpu().numpy()
+        pred_logits = model(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+        acc_contrastive += (
+                    torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
+        acc += (pred_class == label_class)
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
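+    # normalise by group size x dataset size; label.shape[0] (the last batch's group size) is assumed constant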
+    losses = losses / (label.shape[0] * len(data_test.dataset))
+    acc = acc / (len(data_test.dataset))
+    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
+                                                                                      acc_contrastive))
+
+    wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
+             "validation contrastive accuracy": acc_contrastive})
+
+    return losses, acc, acc_contrastive
+
+
+def run_duo():
+    # wandb init; the API key is expected in the WANDB_API_KEY environment
+    # variable rather than being committed to the repository
+    os.environ["WANDB_MODE"] = "offline"
+    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
+
+    wdb.init(project="param_sweep_contrastive", dir='./wandb_run')
+    args = wdb.config  # the sweep parameters become readable once the run is initialised
+    print(args)
+
+    print('Wandb initialised')
+    # load data
+    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                                                base_dir_val=args.dataset_val_dir,
+                                                                base_dir_test=None,
+                                                                batch_size=args.batch_size,
+                                                                ref_dir=args.dataset_ref_dir,
+                                                                positive_prop=args.positive_prop, sampler=args.sampler)
+
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+    model.float()
+    # move parameters to GPU
+    if torch.cuda.is_available():
+        print('Model loaded on GPU')
+        model = model.cuda()
+
+    # init accumulators
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_cont_acc = []
+    val_loss = []
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    if args.opti == 'adam':
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise ValueError('unsupported optimiser: {}'.format(args.opti))
+    # train model
+    for e in range(args.epoches):
+        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e % args.eval_inter == 0:
+            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            val_cont_acc.append(acc_contrastive)
+    wdb.finish()
+
+
+if __name__ == '__main__':
+    run_duo()