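"""Grad-CAM visualisation for the duo contrastive classification model.

Loads an (AER, ANA) input pair and a reference image from .npy files, runs the
model, and overlays the resulting class activation heatmap on the reference image.
"""
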
import numpy as np
import torch
import cv2
from torchvision import transforms
from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader
from image_ref.main import load_model
from image_ref.model import Classification_model_duo_contrastive


def compute_class_activation_map(path_aer, path_ana, path_ref, model_path, model_type='ResNet18'):
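    """Compute and display a Grad-CAM heatmap for one (AER, ANA) input pair against a reference image.

    The three .npy inputs are loaded, transformed and passed through the duo
    contrastive model; activations and gradients are captured on the last block
    of ``im_encoder.layer3``. The heatmap is overlaid on the reference image,
    displayed with OpenCV, and returned as a colour-mapped (uint8 BGR) array
    resized to the reference image.
    """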

    # Transform applied to the AER and ANA input images
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(0),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])

    # Reference transform: same as above but without the noise-threshold step
    ref_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])

    tensor_aer = npy_loader(path_aer)
    tensor_ana = npy_loader(path_ana)
    tensor_ref = npy_loader(path_ref)

    img_ref = np.load(path_ref)

    tensor_aer = transform(tensor_aer)
    tensor_ana = transform(tensor_ana)
    tensor_ref = ref_transform(tensor_ref)

    # Add a batch dimension to each input
    tensor_aer = torch.unsqueeze(tensor_aer, dim=0)
    tensor_ana = torch.unsqueeze(tensor_ana, dim=0)
    tensor_ref = torch.unsqueeze(tensor_ref, dim=0)

    model = Classification_model_duo_contrastive(model=model_type, n_class=2)
    model.double()
    # Load the trained weights
    load_model(model, model_path)
    print('Model loaded')

    # Identify the target layer
    target_layer = model.im_encoder.layer3[-1]
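    # Assumes im_encoder is a ResNet-style backbone (default model_type='ResNet18');
    # layer3[-1] outputs a mid-resolution feature map suitable for the activation map.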

    # Lists to store activations and gradients
    activations = []
    gradients = []

    # Hooks to capture activations and gradients
    def forward_hook(module, input, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    target_layer.register_forward_hook(forward_hook)
    target_layer.register_full_backward_hook(backward_hook)

    # Perform the forward pass
    model.eval()  # Set the model to evaluation mode
    output = model(tensor_aer, tensor_ana, tensor_ref)

    print('Model output:', output)
    pred_class = output.argmax(dim=1).item()

    # Zero the gradients
    model.zero_grad()

    # Backward pass on the logit margin between the predicted class and the other
    # class (a sharper Grad-CAM target than the raw predicted logit alone)
    (output[:, pred_class] - output[:, 1 - pred_class]).backward()
    print('Predicted class:', pred_class)

    # Grad-CAM channel weights: global average of the gradients over the spatial dimensions
    weights = torch.mean(gradients[0], dim=[2, 3])

    # Compute the Grad-CAM heatmap: weighted sum of the activation maps over channels
    heatmap = torch.sum(weights.unsqueeze(dim=2).unsqueeze(dim=3) * activations[0], dim=1).squeeze()
    heatmap = np.maximum(heatmap.cpu().detach().numpy(), 0)  # ReLU: keep positive contributions only
    heatmap /= np.max(heatmap) + 1e-12  # normalise to [0, 1], guarding against an all-zero map

    # Resize the heatmap to match the original image size
    heatmap = cv2.resize(heatmap, (img_ref.shape[1], img_ref.shape[0]))

    # Apply a colormap to the heatmap (OpenCV returns a BGR image)
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    # Colour-map the reference image as well (assumes its values already lie in [0, 1])
    img_ref_bgr = cv2.applyColorMap(np.uint8(255 * img_ref), cv2.COLORMAP_JET)

    # Overlay the heatmap on the reference image
    superimposed_img = cv2.addWeighted(img_ref_bgr, 0.6, heatmap, 0.4, 0)

    # Display the result
    cv2.imshow('Grad-CAM', superimposed_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    return heatmap

if __name__ == '__main__':
    model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_nref.pt'
    path_aer = '../data/processed_data/npy_image/data_test_contrastive/Klebsiella pneumoniae/KLEPNE23_AER.npy'
    path_ana = '../data/processed_data/npy_image/data_test_contrastive/Klebsiella pneumoniae/KLEPNE23_ANA.npy'
    # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
    # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
    path_ref = '../image_ref/img_ref/Proteus mirabilis.npy'  # negative
    # path_ref ='../image_ref/img_ref/Klebsiella pneumoniae.npy'

    compute_class_activation_map(path_aer, path_ana, path_ref, model_path)