Skip to content
Snippets Groups Projects
Commit f3f1fff6 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : default args

parent 80e17afa
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,7 @@ def load_args_contrastive(): ...@@ -17,7 +17,7 @@ def load_args_contrastive():
parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default='saved_model/baseline_resnet18_contrastive_prop_30_bis.pt') parser.add_argument('--pretrain_path', type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
return args return args
\ No newline at end of file
...@@ -22,12 +22,13 @@ def compute_class_activation_map(): ...@@ -22,12 +22,13 @@ def compute_class_activation_map():
[transforms.Resize((224, 224)), [transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_bis.pt'
path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy' path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy' path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
# path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
# path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative
tensor_aer = npy_loader(path_aer) tensor_aer = npy_loader(path_aer)
tensor_ana = npy_loader(path_ana) tensor_ana = npy_loader(path_ana)
tensor_ref = npy_loader(path_ref) tensor_ref = npy_loader(path_ref)
...@@ -47,7 +48,7 @@ def compute_class_activation_map(): ...@@ -47,7 +48,7 @@ def compute_class_activation_map():
model.double() model.double()
# load weight # load weight
if args.pretrain_path is not None: if args.pretrain_path is not None:
load_model(model, args.pretrain_path) load_model(model, model_path)
print('model loaded') print('model loaded')
# Identify the target layer # Identify the target layer
......
...@@ -85,7 +85,7 @@ def run_duo(args): ...@@ -85,7 +85,7 @@ def run_duo(args):
load_model(model,args.pretrain_path) load_model(model,args.pretrain_path)
#move parameters to GPU #move parameters to GPU
if torch.cuda.is_available(): if torch.cuda.is_available():
print('model loaded on GPU') print('Model loaded on GPU')
model = model.cuda() model = model.cuda()
#init accumulators #init accumulators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment