From b26847c5f645a1abeb2d62074f71724ad129037a Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 4 Apr 2025 15:32:37 +0200 Subject: [PATCH] fix : load model test --- image_ref/config.py | 16 ++++++++-------- image_ref/grad_cam.py | 4 ++-- image_ref/main.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index 7ca9999..eae5984 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -7,17 +7,17 @@ def load_args_contrastive(): parser.add_argument('--epoches', type=int, default=0) parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--noise_threshold', type=int, default=0) + parser.add_argument('--noise_threshold', type=int, default=500) parser.add_argument('--lr', type=float, default=0.001) - parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--positive_prop', type=int, default=None) parser.add_argument('--model', type=str, default='ResNet18') - parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive') - parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive') - parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref') - parser.add_argument('--output', type=str, default='../output/out_contrastive.csv') - parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt') - parser.add_argument('--pretrain_path', type=str, default='../saved_model/baseline_resnet18_contrastive_prop_30.pt') + parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') + parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') + parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') + parser.add_argument('--output', type=str, default='output/out_contrastive.csv') + parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') + parser.add_argument('--pretrain_path', type=str, default='saved_model/baseline_resnet18_contrastive_prop_30_bis.pt') args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/grad_cam.py b/image_ref/grad_cam.py index a2de3c3..303f313 100644 --- a/image_ref/grad_cam.py +++ b/image_ref/grad_cam.py @@ -26,8 +26,8 @@ def compute_class_activation_map(): path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy' path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy' # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive - path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative - # path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative + # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative + path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative tensor_aer = npy_loader(path_aer) tensor_ana = npy_loader(path_ana) tensor_ref = npy_loader(path_ref) diff --git a/image_ref/main.py b/image_ref/main.py index 2a6bdc0..8e33af8 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -81,7 +81,7 @@ def run_duo(args): model.double() #load weight if args.pretrain_path is not None : - 'Model weight loaded' + print('Model weight loaded') load_model(model,args.pretrain_path) #move parameters to GPU if torch.cuda.is_available(): @@ -168,7 +168,6 @@ def make_prediction_duo(model, data, f_name, f_name2): img_ref = img_ref.cuda() label = label.cuda() output = model(imaer,imana,img_ref) - print(output) confidence = soft_max(output) confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy()) #Mono class output (only most postive paire) @@ -213,4 +212,5 @@ def load_model(model, path): if __name__ == '__main__': args = load_args_contrastive() + print(args) run_duo(args) \ No newline at end of file -- GitLab