# Repository-viewer header (not part of the program):
# Schneider Leo authored 5433e00f
# grad_cam.py, 3.75 KiB
import numpy as np
import torch
import cv2
from torchvision.transforms import transforms
from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader
from image_ref.main import load_model
from image_ref.model import Classification_model_duo_contrastive
def compute_class_activation_map(path_aer, path_ana, path_ref, model_path,
                                 model_type='ResNet18', show=True):
    """Compute a Grad-CAM heatmap for one (aer, ana, ref) input triplet.

    The three ``.npy`` inputs are loaded with ``npy_loader``, resized to
    224x224, log-normalised and normalised with mean/std 0.5 (the aer and ana
    inputs are additionally passed through ``Threshold_noise(0)``).  A
    two-class ``Classification_model_duo_contrastive`` is built, cast to
    double precision and its weights are loaded from ``model_path``.
    Activations and output-gradients of ``model.im_encoder.layer3[-1]`` are
    captured with hooks, and the gradient of
    ``score[predicted] - score[other]`` is back-propagated to weight the
    activation channels.

    Args:
        path_aer: Path to the ``.npy`` file of the "aer" input.
        path_ana: Path to the ``.npy`` file of the "ana" input.
        path_ref: Path to the ``.npy`` file of the reference image; it is
            also loaded raw with ``np.load`` and used as the overlay
            background.
        model_path: Path of the saved weights given to ``load_model``.
        model_type: Backbone identifier forwarded to the model constructor.
        show: When true (default, matching the historical behaviour) the
            overlay is displayed in an OpenCV window and the call blocks
            until a key is pressed.  Pass ``False`` for headless/batch use.

    Returns:
        The heatmap resized to the first two dimensions of the raw reference
        array and colour-mapped with ``cv2.COLORMAP_JET``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        Threshold_noise(0),
        Log_normalisation(),
        transforms.Normalize(0.5, 0.5),
    ])
    # Same pipeline as above, minus the noise threshold.
    ref_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        Log_normalisation(),
        transforms.Normalize(0.5, 0.5),
    ])

    tensor_aer = npy_loader(path_aer)
    tensor_ana = npy_loader(path_ana)
    tensor_ref = npy_loader(path_ref)
    # Untransformed reference array: gives the output size and the background.
    img_ref = np.load(path_ref)

    # Apply the transforms and add a leading batch dimension of size 1.
    tensor_aer = torch.unsqueeze(transform(tensor_aer), dim=0)
    tensor_ana = torch.unsqueeze(transform(tensor_ana), dim=0)
    tensor_ref = torch.unsqueeze(ref_transform(tensor_ref), dim=0)

    model = Classification_model_duo_contrastive(model=model_type, n_class=2)
    model.double()
    # load weight
    load_model(model, model_path)
    print('model loaded')

    # Layer whose feature maps are explained.
    target_layer = model.im_encoder.layer3[-1]

    # Filled by the hooks during the forward / backward passes.
    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    forward_handle = target_layer.register_forward_hook(forward_hook)
    backward_handle = target_layer.register_full_backward_hook(backward_hook)
    try:
        model.eval()  # Set the model to evaluation mode
        output = model(tensor_aer, tensor_ana, tensor_ref)
        print(output)
        pred_class = output.argmax(dim=1).item()

        model.zero_grad()
        # Back-propagate the margin between the predicted class and the other
        # class (two-class model, hence ``1 - pred_class``) rather than the
        # raw predicted score.
        (output[:, pred_class] - output[:, 1 - pred_class]).backward()
    finally:
        # Always detach the hooks so they cannot keep firing (or keep tensors
        # alive) if the forward/backward pass raises.
        forward_handle.remove()
        backward_handle.remove()
    print('Predicted class ', pred_class)

    # NOTE(review): only the first captured activation and the first captured
    # gradient are used.  If the encoder is invoked more than once per forward
    # pass these may belong to different invocations -- confirm against the
    # model definition.
    # Channel weights: spatial mean of the gradients.
    weights = torch.mean(gradients[0], dim=[2, 3])
    # Weighted sum of the activation channels, then drop singleton dims.
    cam = torch.sum(weights[:, :, None, None] * activations[0], dim=1).squeeze()
    cam = np.maximum(cam.cpu().detach().numpy(), 0)

    # Normalise to [0, 1]; skip when the rectified map is all zeros, which
    # would otherwise produce NaNs (0 / 0) that are then cast to uint8.
    peak = np.max(cam)
    if peak > 0:
        cam /= peak

    # Resize the heatmap to match the original reference image size
    # (cv2.resize expects (width, height)).
    cam = cv2.resize(cam, (img_ref.shape[1], img_ref.shape[0]))

    # Convert both images to colour with the same colormap.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # NOTE(review): assumes img_ref values lie in [0, 1]; larger values would
    # wrap around in the uint8 cast -- confirm against the data.
    ref_rgb = cv2.applyColorMap(np.uint8(255 * img_ref), cv2.COLORMAP_JET)

    # Overlay the heatmap on the reference image.
    superimposed_img = cv2.addWeighted(ref_rgb, 0.6, heatmap, 0.4, 0)

    if show:
        cv2.imshow('Grad-CAM', superimposed_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return heatmap
if __name__ == '__main__':
    # Example run on one test sample (AER + ANA acquisitions of the same
    # isolate) compared against a single reference image.
    path_aer = '../data/processed_data/npy_image/data_test_contrastive/Klebsiella pneumoniae/KLEPNE23_AER.npy'
    path_ana = '../data/processed_data/npy_image/data_test_contrastive/Klebsiella pneumoniae/KLEPNE23_ANA.npy'

    # Reference image to compare against; the alternatives are kept so the
    # reference can be switched quickly.
    # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
    # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
    # path_ref ='../image_ref/img_ref/Klebsiella pneumoniae.npy'
    path_ref = '../image_ref/img_ref/Proteus mirabilis.npy'  # negative

    # Saved weights of the contrastive ResNet18 baseline.
    model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_nref.pt'

    compute_class_activation_map(
        path_aer=path_aer,
        path_ana=path_ana,
        path_ref=path_ref,
        model_path=model_path,
    )