From f3f1fff67cbdf57cb788c265fc89124aed684d44 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 4 Apr 2025 16:20:47 +0200 Subject: [PATCH] fix : default args --- image_ref/config.py | 2 +- image_ref/grad_cam.py | 7 ++++--- image_ref/main.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index bbb13530..aac9449d 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -17,7 +17,7 @@ def load_args_contrastive(): parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') - parser.add_argument('--pretrain_path', type=str, default='saved_model/baseline_resnet18_contrastive_prop_30_bis.pt') + parser.add_argument('--pretrain_path', type=str, default=None) args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/grad_cam.py b/image_ref/grad_cam.py index f8613582..091753e6 100644 --- a/image_ref/grad_cam.py +++ b/image_ref/grad_cam.py @@ -22,12 +22,13 @@ def compute_class_activation_map(): [transforms.Resize((224, 224)), transforms.Normalize(0.5, 0.5)]) + model_path = '../saved_model/baseline_resnet18_contrastive_prop_30_bis.pt' path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy' path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy' - path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive + # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative - # path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative + path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative tensor_aer = npy_loader(path_aer) tensor_ana = npy_loader(path_ana) tensor_ref = npy_loader(path_ref) @@ -47,7 +48,7 @@ def compute_class_activation_map(): model.double() # load weight if args.pretrain_path is not None: - load_model(model, args.pretrain_path) + load_model(model, model_path) print('model loaded') # Identify the target layer diff --git a/image_ref/main.py b/image_ref/main.py index 017daeca..246df886 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -85,7 +85,7 @@ def run_duo(args): load_model(model,args.pretrain_path) #move parameters to GPU if torch.cuda.is_available(): - print('model loaded on GPU') + print('Model loaded on GPU') model = model.cuda() #init accumulators -- GitLab