Commit 78499056 authored by Schneider Leo

fix : load model test

parent 64369620
@@ -4,20 +4,20 @@ import argparse
 def load_args_contrastive():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--epoches', type=int, default=1)
+    parser.add_argument('--epoches', type=int, default=0)
     parser.add_argument('--save_inter', type=int, default=50)
     parser.add_argument('--eval_inter', type=int, default=1)
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--batch_size', type=int, default=16)
     parser.add_argument('--positive_prop', type=int, default=None)
-    parser.add_argument('--model', type=str, default='ResNet50')
+    parser.add_argument('--model', type=str, default='ResNet18')
     parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive')
     parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive')
     parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref')
-    parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
-    parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
-    parser.add_argument('--pretrain_path', type=str, default='../output/best_model_constrastive.pt')
+    parser.add_argument('--output', type=str, default='../output/out_contrastive.csv')
+    parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt')
+    parser.add_argument('--pretrain_path', type=str, default='../saved_model/baseline_resnet18_contrastive_prop_30.pt')
     args = parser.parse_args()
     return args
\ No newline at end of file
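
A minimal usage sketch of the updated defaults, assuming load_args_contrastive is importable from a module named config (the file name is not visible in this diff):

    from config import load_args_contrastive  # assumed module name

    args = load_args_contrastive()
    # With no CLI flags, the new defaults give:
    #   args.epoches       == 0
    #   args.model         == 'ResNet18'
    #   args.pretrain_path == '../saved_model/baseline_resnet18_contrastive_prop_30.pt'
    print(args.epoches, args.model, args.pretrain_path)
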
@@ -114,11 +114,12 @@ def make_dataset_custom(
 class ImageFolderDuo(data.Dataset):
     def __init__(self, root, transform=None, target_transform=None,
-                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None):
+                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, positive_prop=None, ref_transform=None):
         self.root = root
         self.imlist = flist_reader(root)
         self.transform = transform
         self.target_transform = target_transform
+        self.ref_transform = ref_transform
         self.loader = loader
         self.classes = torchvision.datasets.folder.find_classes(root)[0]
         self.ref_dir = ref_dir
@@ -144,7 +145,7 @@ class ImageFolderDuo(data.Dataset):
         if self.transform is not None:
             imgAER = self.transform(imgAER)
             imgANA = self.transform(imgANA)
-            img_ref = self.transform(img_ref)
+            img_ref = self.ref_transform(img_ref)
         contrastive_target = 0 if target == label_ref else 1
         return imgAER, imgANA, img_ref, contrastive_target
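
The reference image now gets its own preprocessing via ref_transform instead of reusing the pair transform; note that __getitem__ calls self.ref_transform whenever self.transform is set, so a ref_transform should always be supplied alongside transform. A hypothetical construction call mirroring the updated signature (the real call site in load_data_duo is outside the shown hunks, and train_transform is a placeholder name):

    # Sketch only: every name except ImageFolderDuo's own parameters is an assumption.
    train_set = ImageFolderDuo(
        root=base_dir_train,
        transform=train_transform,     # applied to the AER/ANA image pair
        ref_transform=ref_transform,   # applied only to the reference image
        ref_dir='../image_ref/img_ref',
        positive_prop=None,
    )
    imgAER, imgANA, img_ref, contrastive_target = train_set[0]
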
@@ -169,6 +170,8 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise
     ref_transform = transforms.Compose(
         [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
...
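
Threshold_noise and Log_normalisation are project-specific transforms defined outside this diff; a hypothetical minimal version, assuming they zero out intensities below the noise threshold and log-rescale to [0, 1] respectively (the real implementations may differ):

    import torch

    class Threshold_noise:
        # Assumed behaviour: suppress values below a fixed noise threshold.
        def __init__(self, threshold):
            self.threshold = threshold

        def __call__(self, img):
            return torch.where(img > self.threshold, img, torch.zeros_like(img))

    class Log_normalisation:
        # Assumed behaviour: log-scale intensities into [0, 1].
        def __init__(self, eps=1e-6):
            self.eps = eps

        def __call__(self, img):
            return torch.log1p(img) / (torch.log1p(img.max()) + self.eps)

Both are plain callables, so they slot directly into the transforms.Compose pipeline shown above.
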
@@ -25,18 +25,18 @@ def compute_class_activation_map():
     path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
     path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
-    path_ref ='../image_ref/img_ref/Citrobacter freundii.npy'
+    # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
+    path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
+    # path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative
     tensor_aer = npy_loader(path_aer)
     tensor_ana = npy_loader(path_ana)
     tensor_ref = npy_loader(path_ref)
     img_ref = np.load(path_ref)
     tensor_aer = transform(tensor_aer)
     tensor_ana = transform(tensor_ana)
-    tensor_ref = ref_transform(tensor_ref)
+    tensor_ref = transform(tensor_ref)
     tensor_aer = torch.unsqueeze(tensor_aer, dim=0)
     tensor_ana = torch.unsqueeze(tensor_ana, dim=0)
@@ -70,6 +70,8 @@ def compute_class_activation_map():
     # Perform the forward pass
     model.eval() # Set the model to evaluation mode
     output = model(tensor_aer,tensor_ana,tensor_ref)
+    print(output)
     pred_class = output.argmax(dim=1).item()
     # Zero the gradients
@@ -77,6 +79,7 @@ def compute_class_activation_map():
     # Backward pass to compute gradients
     output[:, pred_class].backward()
+    print('Predicted class ',pred_class)
     # Compute the weights
     weights = torch.mean(gradients[0], dim=[2, 3])
...
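
The weights line above is the standard Grad-CAM channel weighting (gradients averaged over the spatial dimensions). A self-contained sketch of the rest of the map computation, using dummy tensors in place of the hook-captured gradients and activations (shapes are assumptions):

    import torch
    import torch.nn.functional as F

    # Stand-ins for the tensors captured by forward/backward hooks on the last conv layer.
    activations = [torch.randn(1, 512, 7, 7)]
    gradients = [torch.randn(1, 512, 7, 7)]

    weights = torch.mean(gradients[0], dim=[2, 3])                      # (1, C) channel weights
    cam = torch.sum(weights[:, :, None, None] * activations[0], dim=1)  # (1, H, W)
    cam = F.relu(cam)                                                    # keep positive evidence only
    cam = cam / (cam.max() + 1e-8)                                       # normalise to [0, 1]
    cam = F.interpolate(cam.unsqueeze(1), size=(224, 224),               # upsample to input resolution
                        mode='bilinear', align_corners=False)
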
@@ -81,6 +81,7 @@ def run_duo(args):
     model.double()
     #load weight
     if args.pretrain_path is not None :
+        'Model weight loaded'
         load_model(model,args.pretrain_path)
     #move parameters to GPU
     if torch.cuda.is_available():
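
load_model is defined elsewhere in the repository; a minimal sketch of what such a helper usually does, assuming the checkpoint holds a plain state_dict (an assumption, not the project's confirmed implementation). The bare string 'Model weight loaded' added above the call is a no-op at runtime and would need a print() to show up in the logs:

    import torch

    def load_model(model, path):
        # Assumption: the .pt file was written with torch.save(model.state_dict(), path).
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict)
        print('Model weight loaded from', path)
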
@@ -134,12 +135,12 @@ def run_duo(args):
     plt.show()
-    plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop))
+    plt.savefig('../output/training_plot_contrastive_{}.png'.format(args.positive_prop))
     #load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test_batch, 'output/confusion_matrix_contractive_{}_bis.png'.format(args.positive_prop),
-                        'output/confidence_matrix_contractive_{}_bis.png'.format(args.positive_prop))
+    make_prediction_duo(model,data_test_batch, '../output/confusion_matrix_contractive_{}_bis.png'.format(args.positive_prop),
+                        '../output/confidence_matrix_contractive_{}_bis.png'.format(args.positive_prop))

 def make_prediction_duo(model, data, f_name, f_name2):
@@ -167,6 +168,7 @@ def make_prediction_duo(model, data, f_name, f_name2):
             img_ref = img_ref.cuda()
             label = label.cuda()
         output = model(imaer,imana,img_ref)
+        print(output)
         confidence = soft_max(output)
         confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy())
         #Mono class output (only most postive paire)
...
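
In make_prediction_duo, column 0 of the softmax is kept as the confidence, matching the dataset convention that contrastive_target == 0 marks a positive (matching) reference pair. A small self-contained sketch of that step, with dummy logits standing in for model(imaer, imana, img_ref):

    import torch

    soft_max = torch.nn.Softmax(dim=1)

    output = torch.randn(4, 2)                  # dummy (batch, 2) logits
    confidence = soft_max(output)
    positive_confidence = confidence[:, 0]      # P(pair matches the reference species)
    predicted_match = positive_confidence > 0.5
    print(positive_confidence, predicted_match)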