Skip to content
Snippets Groups Projects
Commit 9f968c03 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : grad CAM

parent d0df642b
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ def load_args_contrastive():
parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref')
parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default=None)
parser.add_argument('--pretrain_path', type=str, default='../output/best_model_constrastive.pt')
args = parser.parse_args()
return args
\ No newline at end of file
......@@ -167,8 +167,13 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
transforms.Normalize(0.5, 0.5)])
print('Default val transform')
train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop)
val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir)
ref_transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)])
print('Default val transform')
train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
data_loader_train = data.DataLoader(
dataset=train_dataset,
......@@ -193,10 +198,11 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
class ImageFolderDuo_Batched(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None):
flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None):
self.root = root
self.imlist = flist_reader(root)
self.transform = transform
self.ref_transform = ref_transform
self.target_transform = target_transform
self.loader = loader
self.classes = torchvision.datasets.folder.find_classes(root)[0]
......@@ -214,7 +220,7 @@ class ImageFolderDuo_Batched(data.Dataset):
path_ref = self.ref_dir +'/'+ class_ref + '.npy'
img_ref = self.loader(path_ref)
if self.transform is not None:
img_ref = self.transform(img_ref)
img_ref = self.ref_transform(img_ref)
img_refs.append(img_ref)
label_refs.append(target_ref)
if self.transform is not None:
......
import numpy as np
import torch
import cv2
from torchvision.transforms import transforms
from image_ref.config import load_args_contrastive
from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader
from image_ref.main import load_model
from image_ref.model import Classification_model_duo_contrastive
def compute_class_activation_map():
args = load_args_contrastive()
transform = transforms.Compose(
[transforms.Resize((224, 224)),
Threshold_noise(500),
Log_normalisation(),
transforms.Normalize(0.5, 0.5)])
ref_transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)])
path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
path_ref ='../image_ref/img_ref/Citrobacter freundii.npy'
tensor_aer = npy_loader(path_aer)
tensor_ana = npy_loader(path_ana)
tensor_ref = npy_loader(path_ref)
img_ref = np.load(path_ref)
tensor_aer = transform(tensor_aer)
tensor_ana = transform(tensor_ana)
tensor_ref = ref_transform(tensor_ref)
tensor_aer = torch.unsqueeze(tensor_aer, dim=0)
tensor_ana = torch.unsqueeze(tensor_ana, dim=0)
tensor_ref = torch.unsqueeze(tensor_ref, dim=0)
model = Classification_model_duo_contrastive(model=args.model, n_class=2)
model.double()
# load weight
if args.pretrain_path is not None:
load_model(model, args.pretrain_path)
print('model loaded')
# Identify the target layer
target_layer = model.im_encoder.layer4[-1]
# Lists to store activations and gradients
activations = []
gradients = []
# Hooks to capture activations and gradients
def forward_hook(module, input, output):
activations.append(output)
def backward_hook(module, grad_input, grad_output):
gradients.append(grad_output[0])
target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)
# Perform the forward pass
model.eval() # Set the model to evaluation mode
output = model(tensor_aer,tensor_ana,tensor_ref)
pred_class = output.argmax(dim=1).item()
# Zero the gradients
model.zero_grad()
# Backward pass to compute gradients
output[:, pred_class].backward()
# Compute the weights
weights = torch.mean(gradients[0], dim=[2, 3])
# Compute the Grad-CAM heatmap
heatmap = torch.sum(weights.unsqueeze(dim=2).unsqueeze(dim=2) * activations[0], dim=1).squeeze()
heatmap = np.maximum(heatmap.cpu().detach().numpy(), 0)
heatmap /= np.max(heatmap)
# Resize the heatmap to match the original image size
heatmap = cv2.resize(heatmap, (img_ref.shape[1], img_ref.shape[0]))
# Convert heatmap to RGB format and apply colormap
heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
img_aer_rgb = cv2.applyColorMap(np.uint8(255 * img_ref), cv2.COLORMAP_JET)
# Overlay the heatmap on the original image
superimposed_img = cv2.addWeighted(img_aer_rgb, 0.6, heatmap, 0.4, 0)
# Display the result
cv2.imshow('Grad-CAM', superimposed_img)
cv2.waitKey(0)
cv2.destroyAllWindows()
return heatmap
if __name__ =='__main__':
compute_class_activation_map()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment