diff --git a/image_ref/config.py b/image_ref/config.py index 7a6ca0f915dbdc413a568b1527b028ed32d875c9..c679960fc07d12c8a04a0e5b9b5179ecdec90974 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -17,7 +17,7 @@ def load_args_contrastive(): parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') - parser.add_argument('--pretrain_path', type=str, default=None) + parser.add_argument('--pretrain_path', type=str, default='../output/best_model_constrastive.pt') args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index a086f2d316cae2ebdad6ad280a71d12b5d2c4d54..88913c55df2199b99f67c267e0e110183a3073a0 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -167,8 +167,13 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise transforms.Normalize(0.5, 0.5)]) print('Default val transform') - train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop) - val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir) + ref_transform = transforms.Compose( + [transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.5)]) + print('Default val transform') + + train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform) + val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform) data_loader_train = data.DataLoader( dataset=train_dataset, @@ -193,10 +198,11 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise class ImageFolderDuo_Batched(data.Dataset): def __init__(self, root, transform=None, target_transform=None, - flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None): + flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None): self.root = root self.imlist = flist_reader(root) self.transform = transform + self.ref_transform = ref_transform self.target_transform = target_transform self.loader = loader self.classes = torchvision.datasets.folder.find_classes(root)[0] @@ -214,7 +220,7 @@ class ImageFolderDuo_Batched(data.Dataset): path_ref = self.ref_dir +'/'+ class_ref + '.npy' img_ref = self.loader(path_ref) if self.transform is not None: - img_ref = self.transform(img_ref) + img_ref = self.ref_transform(img_ref) img_refs.append(img_ref) label_refs.append(target_ref) if self.transform is not None: diff --git a/image_ref/grad_cam.py b/image_ref/grad_cam.py new file mode 100644 index 0000000000000000000000000000000000000000..57cd4e2fb91eaa22d0a0674ef53f8a0dbdddf7c4 --- /dev/null +++ b/image_ref/grad_cam.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +import cv2 +from torchvision.transforms import transforms + +from image_ref.config import load_args_contrastive +from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader +from image_ref.main import load_model +from image_ref.model import Classification_model_duo_contrastive + + +def compute_class_activation_map(): + args = load_args_contrastive() + + transform = transforms.Compose( + [transforms.Resize((224, 224)), + Threshold_noise(500), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + + ref_transform = transforms.Compose( + [transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.5)]) + + + path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy' + path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy' + path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' + + tensor_aer = npy_loader(path_aer) + tensor_ana = npy_loader(path_ana) + tensor_ref = npy_loader(path_ref) + + img_ref = np.load(path_ref) + + + tensor_aer = transform(tensor_aer) + tensor_ana = transform(tensor_ana) + tensor_ref = ref_transform(tensor_ref) + + tensor_aer = torch.unsqueeze(tensor_aer, dim=0) + tensor_ana = torch.unsqueeze(tensor_ana, dim=0) + tensor_ref = torch.unsqueeze(tensor_ref, dim=0) + + + model = Classification_model_duo_contrastive(model=args.model, n_class=2) + model.double() + # load weight + if args.pretrain_path is not None: + load_model(model, args.pretrain_path) + print('model loaded') + + # Identify the target layer + target_layer = model.im_encoder.layer4[-1] + + # Lists to store activations and gradients + activations = [] + gradients = [] + + # Hooks to capture activations and gradients + def forward_hook(module, input, output): + activations.append(output) + + def backward_hook(module, grad_input, grad_output): + gradients.append(grad_output[0]) + + target_layer.register_forward_hook(forward_hook) + target_layer.register_full_backward_hook(backward_hook) + + # Perform the forward pass + model.eval() # Set the model to evaluation mode + output = model(tensor_aer,tensor_ana,tensor_ref) + pred_class = output.argmax(dim=1).item() + + # Zero the gradients + model.zero_grad() + + # Backward pass to compute gradients + output[:, pred_class].backward() + + # Compute the weights + weights = torch.mean(gradients[0], dim=[2, 3]) + + # Compute the Grad-CAM heatmap + heatmap = torch.sum(weights.unsqueeze(dim=2).unsqueeze(dim=2) * activations[0], dim=1).squeeze() + heatmap = np.maximum(heatmap.cpu().detach().numpy(), 0) + heatmap /= np.max(heatmap) + + + + + # Resize the heatmap to match the original image size + heatmap = cv2.resize(heatmap, (img_ref.shape[1], img_ref.shape[0])) + + # Convert heatmap to RGB format and apply colormap + heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET) + + img_aer_rgb = cv2.applyColorMap(np.uint8(255 * img_ref), cv2.COLORMAP_JET) + + # Overlay the heatmap on the original image + superimposed_img = cv2.addWeighted(img_aer_rgb, 0.6, heatmap, 0.4, 0) + + # Display the result + cv2.imshow('Grad-CAM', superimposed_img) + cv2.waitKey(0) + cv2.destroyAllWindows() + + return heatmap + +if __name__ =='__main__': + compute_class_activation_map() \ No newline at end of file