Skip to content
Snippets Groups Projects
Commit 10edb8ba authored by Schneider Leo's avatar Schneider Leo
Browse files

clean : grad_cam

fix : error in E.COLI f_name
parent f4753ec3
No related branches found
No related tags found
No related merge requests found
No preview for this file type
......@@ -2,15 +2,12 @@ import numpy as np
import torch
import cv2
from torchvision.transforms import transforms
from image_ref.config import load_args_contrastive
from image_ref.dataset_ref import Threshold_noise, Log_normalisation, npy_loader
from image_ref.main import load_model
from image_ref.model import Classification_model_duo_contrastive
def compute_class_activation_map():
args = load_args_contrastive()
def compute_class_activation_map(path_aer, path_ana, path_ref, model_path, model_type='Resnet18'):
transform = transforms.Compose(
[transforms.Resize((224, 224)),
......@@ -22,13 +19,8 @@ def compute_class_activation_map():
[transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)])
model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_bis.pt'
path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
# path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
# path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative
tensor_aer = npy_loader(path_aer)
tensor_ana = npy_loader(path_ana)
tensor_ref = npy_loader(path_ref)
......@@ -44,12 +36,11 @@ def compute_class_activation_map():
tensor_ref = torch.unsqueeze(tensor_ref, dim=0)
model = Classification_model_duo_contrastive(model=args.model, n_class=2)
model = Classification_model_duo_contrastive(model=model_type, n_class=2)
model.double()
# load weight
if args.pretrain_path is not None:
load_model(model, model_path)
print('model loaded')
load_model(model, model_path)
print('model loaded')
# Identify the target layer
target_layer = model.im_encoder.layer4[-1]
......@@ -112,24 +103,11 @@ def compute_class_activation_map():
return heatmap
if __name__ =='__main__':
# compute_class_activation_map()
transform = transforms.Compose(
[transforms.Resize((224, 224)),
Threshold_noise(500),
Log_normalisation(),
transforms.Normalize(0.5, 0.5)])
ref_transform = transforms.Compose(
[transforms.Resize((224, 224)),
Threshold_noise(0),
Log_normalisation(),
transforms.Normalize(0.5, 0.5)
])
path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' # negative
tensor_ref = npy_loader(path_ref)
model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_bis.pt'
path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
# path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
# path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative
ref_base = tensor_ref.squeeze()
ref_false = transform(tensor_ref).squeeze()
ref_true = ref_transform(tensor_ref).squeeze()
\ No newline at end of file
compute_class_activation_map(path_aer, path_ana, path_ref, model_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment