diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_AER.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_AER.npy
deleted file mode 100644
index 79ae914db62e0a8431735a93009f531d0df607b5..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_AER.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_ANA.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_ANA.npy
deleted file mode 100644
index 079544a37e1214911184e21240b89f3819a67406..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC13_ANA.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_AER.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_AER.npy
deleted file mode 100644
index 3fbb8e88073eaaf531b2a588621b2bd8df58666c..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_AER.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_ANA.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_ANA.npy
deleted file mode 100644
index 4cbf8aa3badd150f70d1dd81fa224f36fa29cffe..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC1_ANA.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_AER.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_AER.npy
deleted file mode 100644
index 7cec8452920708899ade593be6d3525aa580299a..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_AER.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_ANA.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_ANA.npy
deleted file mode 100644
index 29bdf625a01fb42e1688be4411eec280ba0485bf..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC3_ANA.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_AER.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_AER.npy
deleted file mode 100644
index 18903cb618a16de6c8434565137e05ac520e3c55..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_AER.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_ANA.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_ANA.npy
deleted file mode 100644
index edbdf2d38faea592c3e129804d760b5b16bd481a..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC4_ANA.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_AER.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_AER.npy
deleted file mode 100644
index 92d2f234c6b9e1b945b678cc4e89e31ba4e5d297..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_AER.npy and /dev/null differ
diff --git a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_ANA.npy b/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_ANA.npy
deleted file mode 100644
index b22901839ad6e0cd10eeb1d37b93ce4e7d7563b0..0000000000000000000000000000000000000000
Binary files a/data/processed_data/npy_image/data_training_contrastive/Klebsiella michiganensis/KLEMIC7_ANA.npy and /dev/null differ
diff --git a/image_ref/config.py b/image_ref/config.py
index 26c1bdfe544fc78f6309a2908b77fe43ba54611e..a63c21726c266d0ed7f134a8952f3d62bf4a5213 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -12,7 +12,6 @@ def load_args_contrastive():
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--positive_prop', type=int, default=None)
     parser.add_argument('--model', type=str, default='ResNet18')
-    parser.add_argument('--model_type', type=str, default='duo')
     parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
     parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index be91a4cf616adb2a6ff7ab6621e377ddb8e1a7df..5dd8ff3040094f74f290ddd695b26cc77cda0c2e 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -192,5 +192,86 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0, ref_dir
 
     return data_loader_train, data_loader_test
 
-def load_data():
-    raise 'Not implemented'
\ No newline at end of file
+
+class ImageFolderDuo_Batched(data.Dataset):
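+    """ImageFolderDuo variant whose items come pre-batched: the (AER, ANA)
+    pair is repeated once per class and matched with that class's reference
+    image, so a single item covers every reference."""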
+    def __init__(self, root, transform=None, target_transform=None,
+                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None):
+        self.root = root
+        self.imlist = flist_reader(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.loader = loader
+        self.classes = torchvision.datasets.folder.find_classes(root)[0]
+        self.ref_dir = ref_dir
+
+    def __getitem__(self, index):
+        impathAER, impathANA, target = self.imlist[index]
+        imgAER = self.loader(impathAER)
+        imgANA = self.loader(impathANA)
+        img_refs = []
+        label_refs = []
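+        # One (reference image, label) pair per class: 0 when the class
+        # matches the sample's target, 1 otherwise.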
+        for ind_ref in range(len(self.classes)):
+            class_ref = self.classes[ind_ref]
+            target_ref = 0 if target == ind_ref else 1
+            path_ref = self.ref_dir + '/' + class_ref + '.npy'
+            img_ref = self.loader(path_ref)
+            if self.transform is not None:
+                img_ref = self.transform(img_ref)
+            img_refs.append(img_ref)
+            label_refs.append(target_ref)
+        if self.transform is not None:
+            imgAER = self.transform(imgAER)
+            imgANA = self.transform(imgANA)
+
+        # Assumes each transformed image has shape (1, H, W): the refs stack
+        # into (n_classes, H, W) and the sample pair is repeated to match.
+        batched_im_ref = torch.concat(img_refs, dim=0)
+        batched_label = torch.tensor(label_refs)
+        batched_imgAER = imgAER.repeat(len(self.classes), 1, 1)
+        batched_imgANA = imgANA.repeat(len(self.classes), 1, 1)
+
+        return batched_imgAER, batched_imgANA, batched_im_ref, batched_label
+
+    def __len__(self):
+        return len(self.imlist)
+
+def load_data_duo_batched(base_dir, shuffle=True, noise_threshold=0, ref_dir = None):
+    train_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
+
+    val_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
+
+    train_dataset = ImageFolderDuo_Batched(root=base_dir, transform=train_transform, ref_dir = ref_dir)
+    val_dataset = ImageFolderDuo_Batched(root=base_dir, transform=val_transform, ref_dir = ref_dir)
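+    # Split by one seeded permutation so train/val stay disjoint and
+    # reproducible across runs.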
+    generator1 = torch.Generator().manual_seed(42)
+    indices = torch.randperm(len(train_dataset), generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
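+    # batch_size=1: each dataset item is already an n_classes-sized batch.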
+    data_loader_train = data.DataLoader(
+        dataset=train_dataset,
+        batch_size=1,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = data.DataLoader(
+        dataset=val_dataset,
+        batch_size=1,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
diff --git a/image_ref/main.py b/image_ref/main.py
index d65d0c8e7f31d2259271177123f609a225b6461e..ff5690919a3142d717eb3b3603f3987b3fd39c39 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 
 from config import load_args_contrastive
-from dataset_ref import load_data, load_data_duo
+from dataset_ref import load_data_duo_batched, load_data_duo
 import torch
 import torch.nn as nn
 from model import Classification_model_contrastive, Classification_model_duo_contrastive
@@ -11,129 +11,6 @@ from sklearn.metrics import confusion_matrix
 import seaborn as sn
 import pandas as pd
 
-
-
-def train(model, data_train, optimizer, loss_function, epoch):
-    model.train()
-    losses = 0.
-    acc = 0.
-    for param in model.parameters():
-        param.requires_grad = True
-
-    for im, label in data_train:
-        label = label.long()
-        if torch.cuda.is_available():
-            im, label = im.cuda(), label.cuda()
-        pred_logits = model.forward(im)
-        pred_class = torch.argmax(pred_logits,dim=1)
-        acc += (pred_class==label).sum().item()
-        loss = loss_function(pred_logits,label)
-        losses += loss.item()
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-    losses = losses/len(data_train.dataset)
-    acc = acc/len(data_train.dataset)
-    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
-    return losses, acc
-
-def test(model, data_test, loss_function, epoch):
-    model.eval()
-    losses = 0.
-    acc = 0.
-    for param in model.parameters():
-        param.requires_grad = False
-
-    for im, label in data_test:
-        label = label.long()
-        if torch.cuda.is_available():
-            im, label = im.cuda(), label.cuda()
-        pred_logits = model.forward(im)
-        pred_class = torch.argmax(pred_logits,dim=1)
-        acc += (pred_class==label).sum().item()
-        loss = loss_function(pred_logits,label)
-        losses += loss.item()
-    losses = losses/len(data_test.dataset)
-    acc = acc/len(data_test.dataset)
-    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
-    return losses,acc
-
-def run(args):
-    #load data
-    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
-    #load model
-    model = Classification_model_contrastive(model = args.model, n_class=2,
-                                             ref_dir = args.dataset_ref_dir)
-    #load weights
-    if args.pretrain_path is not None :
-        load_model(model,args.pretrain_path)
-    #move parameters to GPU
-    if torch.cuda.is_available():
-        model = model.cuda()
-    #init accumulator
-    best_acc = 0
-    train_acc=[]
-    train_loss=[]
-    val_acc=[]
-    val_loss=[]
-    #init training
-    loss_function = nn.CrossEntropyLoss()
-    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-    #traing
-    for e in range(args.epoches):
-        loss, acc = train(model,data_train,optimizer,loss_function,e)
-        train_loss.append(loss)
-        train_acc.append(acc)
-        if e%args.eval_inter==0 :
-            loss, acc = test(model,data_test,loss_function,e)
-            val_loss.append(loss)
-            val_acc.append(acc)
-            if acc > best_acc :
-                save_model(model,args.save_path)
-                best_acc = acc
-    #plot and save training figs
-    plt.plot(train_acc)
-    plt.plot(val_acc)
-    plt.plot(train_acc)
-    plt.plot(train_acc)
-    plt.ylim(0, 1.05)
-    plt.show()
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
-
-    #load and evaluated best model
-    load_model(model, args.save_path)
-    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
-
-
-def make_prediction(model, data, f_name):
-    y_pred = []
-    y_true = []
-
-    # iterate over test data
-    for im, label in data:
-        label = label.long()
-        if torch.cuda.is_available():
-            im = im.cuda()
-        output = model(im)
-
-        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
-        y_pred.extend(output)
-
-        label = label.data.cpu().numpy()
-        y_true.extend(label)  # Save Truth
-    # constant for classes
-
-    classes = data.dataset.dataset.classes
-
-    # Build confusion matrix
-    cf_matrix = confusion_matrix(y_true, y_pred)
-    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
-                         columns=[i for i in classes])
-    plt.figure(figsize=(12, 7))
-    sn.heatmap(df_cm, annot=cf_matrix)
-    plt.savefig(f_name)
-
-
 def train_duo(model, data_train, optimizer, loss_function, epoch):
     model.train()
     losses = 0.
@@ -189,6 +66,8 @@ def run_duo(args):
     #load data
     data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,
                                           ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop)
+    _, data_test_batch = load_data_duo_batched(base_dir=args.dataset_dir,
+                                               ref_dir=args.dataset_ref_dir)
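+    # Only evaluation uses the batched loader: it pairs each sample with
+    # every class reference in a single forward pass.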
     #load model
     model = Classification_model_duo_contrastive(model = args.model, n_class=2)
     model.double()
@@ -214,7 +93,7 @@ def run_duo(args):
         train_loss.append(loss)
         train_acc.append(acc)
         if e%args.eval_inter==0 :
-            loss, acc = test_duo(model,data_test,loss_function,e)
+            loss, acc = test_duo(model,data_test_batch,loss_function,e)
             val_loss.append(loss)
             val_acc.append(acc)
             if acc > best_acc :
@@ -275,7 +154,4 @@ def load_model(model, path):
 
 if __name__ == '__main__':
     args = load_args_contrastive()
-    if args.model_type=='duo':
-        run_duo(args)
-    else :
-        run(args)
\ No newline at end of file
+    run_duo(args)
diff --git a/image_ref/utils.py b/image_ref/utils.py
index 093ac7c5328f7550c99308ad6d5cdc97128cfdde..02eb8fd11461ca9f9aecf90c8cbde9b194f359c0 100644
--- a/image_ref/utils.py
+++ b/image_ref/utils.py
@@ -237,27 +237,28 @@ def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, ma
 
 if __name__ == '__main__':
     df = build_database_ref_peptide()
-    for spe in SPECIES:
-
-        df_spe = df[df['Specie']==spe]
-        spe_list = df_spe['Sequence'].to_list()
-        with open('fasta/optimal peptide set/'+spe+'.fasta',"w") as f:
-            if not spe_list:
-                print('Empty : ',spe)
-            for pep in spe_list :
-                f.write(pep)
-
-    #
+    # Write one fasta file per species (currently disabled)
+    # for spe in SPECIES:
     #
-    # df_full = load_lib(
-    #     'fasta/full proteom/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
-    # min_rt = df_full['RT'].min()
-    # max_rt = df_full['RT'].max()
+    #     df_spe = df[df['Specie']==spe]
+    #     spe_list = df_spe['Sequence'].to_list()
+    #     with open('fasta/optimal peptide set/'+spe+'.fasta',"w") as f:
+    #         if not spe_list:
+    #             print('Empty : ',spe)
+    #         for pep in spe_list :
+    #             f.write(pep)
+
     #
-    # for spe in SPECIES:
-    #     im = build_ref_image_from_diann(
-    #         'fasta/optimal peptide set/' + spe + '.parquet', ms1_end_mz=1250,
-    #         ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
-    #     plt.clf()
-    #     mpimg.imsave('img_ref/' + spe + '.png', im)
-    #     np.save(spe + '.npy', im)
+    # Create the per-species reference images
+    df_full = load_lib(
+        'fasta/full proteom/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
+    min_rt = df_full['RT'].min()
+    max_rt = df_full['RT'].max()
+
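+    # RT bounds from the full predicted proteome keep every species'
+    # reference image on a shared retention-time scale.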
+    for spe in SPECIES:
+        im = build_ref_image_from_diann(
+            'fasta/optimal peptide set/' + spe + '.parquet', ms1_end_mz=1250,
+            ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
+        plt.clf()
+        mpimg.imsave('img_ref/' + spe + '.png', im)
+        np.save('img_ref/' + spe + '.npy', im)  # write into img_ref/ alongside the .png