diff --git a/image_ref/config.py b/image_ref/config.py
index 7a6ca0f915dbdc413a568b1527b028ed32d875c9..c679960fc07d12c8a04a0e5b9b5179ecdec90974 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -17,7 +17,7 @@ def load_args_contrastive():
     parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
-    parser.add_argument('--pretrain_path', type=str, default=None)
+    parser.add_argument('--pretrain_path', type=str, default='../output/best_model_constrastive.pt')
     args = parser.parse_args()
 
     return args
\ No newline at end of file
diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py
index a086f2d316cae2ebdad6ad280a71d12b5d2c4d54..88913c55df2199b99f67c267e0e110183a3073a0 100644
--- a/image_ref/dataset_ref.py
+++ b/image_ref/dataset_ref.py
@@ -167,8 +167,13 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
 
-    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop)
-    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir)
+    ref_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default ref transform')
+
+    train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
+    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
 
     data_loader_train = data.DataLoader(
         dataset=train_dataset,
@@ -193,10 +198,11 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
 
 class ImageFolderDuo_Batched(data.Dataset):
     def __init__(self, root, transform=None, target_transform=None,
-                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None):
+                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None):
         self.root = root
         self.imlist = flist_reader(root)
         self.transform = transform
+        self.ref_transform = ref_transform
         self.target_transform = target_transform
         self.loader = loader
         self.classes = torchvision.datasets.folder.find_classes(root)[0]
@@ -214,7 +220,7 @@ class ImageFolderDuo_Batched(data.Dataset):
             path_ref = self.ref_dir +'/'+ class_ref + '.npy'
             img_ref = self.loader(path_ref)
-            if self.transform is not None:
+            if self.ref_transform is not None:
-                img_ref = self.transform(img_ref)
+                img_ref = self.ref_transform(img_ref)
             img_refs.append(img_ref)
             label_refs.append(target_ref)
         if self.transform is not None:
diff --git a/image_ref/grad_cam.py b/image_ref/grad_cam.py
new file mode 100644
index 0000000000000000000000000000000000000000..57cd4e2fb91eaa22d0a0674ef53f8a0dbdddf7c4
--- /dev/null
+++ b/image_ref/grad_cam.py
@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+import cv2
+from torchvision.transforms import transforms
+
+from image_ref.config import load_args_contrastive
+from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader
+from image_ref.main import load_model
+from image_ref.model import Classification_model_duo_contrastive
+
+
+def compute_class_activation_map():
+    args = load_args_contrastive()
+
+    transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(500),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+
+    ref_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         transforms.Normalize(0.5, 0.5)])
+
+
+    path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
+    path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
+    path_ref ='../image_ref/img_ref/Citrobacter freundii.npy'
+
+    tensor_aer = npy_loader(path_aer)
+    tensor_ana = npy_loader(path_ana)
+    tensor_ref = npy_loader(path_ref)
+
+    img_ref = np.load(path_ref)
+
+
+    tensor_aer = transform(tensor_aer)
+    tensor_ana = transform(tensor_ana)
+    tensor_ref = ref_transform(tensor_ref)
+
+    tensor_aer = torch.unsqueeze(tensor_aer, dim=0)
+    tensor_ana = torch.unsqueeze(tensor_ana, dim=0)
+    tensor_ref = torch.unsqueeze(tensor_ref, dim=0)
+
+
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+    model.double()
+    # load weight
+    if args.pretrain_path is not None:
+        load_model(model, args.pretrain_path)
+        print('model loaded')
+
+    # Identify the target layer
+    target_layer = model.im_encoder.layer4[-1]
+
+    # Lists to store activations and gradients
+    activations = []
+    gradients = []
+
+    # Hooks to capture activations and gradients
+    def forward_hook(module, input, output):
+        activations.append(output)
+
+    def backward_hook(module, grad_input, grad_output):
+        gradients.append(grad_output[0])
+
+    target_layer.register_forward_hook(forward_hook)
+    target_layer.register_full_backward_hook(backward_hook)
+
+    # Perform the forward pass
+    model.eval()  # Set the model to evaluation mode
+    output = model(tensor_aer,tensor_ana,tensor_ref)
+    pred_class = output.argmax(dim=1).item()
+
+    # Zero the gradients
+    model.zero_grad()
+
+    # Backward pass to compute gradients
+    output[:, pred_class].backward()
+
+    # Compute the weights
+    weights = torch.mean(gradients[0], dim=[2, 3])
+
+    # Compute the Grad-CAM heatmap
+    heatmap = torch.sum(weights.unsqueeze(dim=2).unsqueeze(dim=2)  * activations[0], dim=1).squeeze()
+    heatmap = np.maximum(heatmap.cpu().detach().numpy(), 0)
+    heatmap /= (np.max(heatmap) + 1e-8)  # epsilon guards against an all-zero heatmap
+
+
+
+
+    # Resize the heatmap to match the original image size
+    heatmap = cv2.resize(heatmap, (img_ref.shape[1], img_ref.shape[0]))
+
+    # Convert heatmap to RGB format and apply colormap
+    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
+
+    img_ref_rgb = cv2.applyColorMap(np.uint8(255 * img_ref), cv2.COLORMAP_JET)
+
+    # Overlay the heatmap on the reference image
+    superimposed_img = cv2.addWeighted(img_ref_rgb, 0.6, heatmap, 0.4, 0)
+
+    # Display the result
+    cv2.imshow('Grad-CAM', superimposed_img)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+    return heatmap
+
+if __name__ =='__main__':
+    compute_class_activation_map()
\ No newline at end of file