diff --git a/image_ref/config.py b/image_ref/config.py index 8b38951a90fb158c810fdc03e0d95c77e17d65c6..d0d52db627854b4c2283a1db5f932c7e15314353 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -16,6 +16,7 @@ def load_args_contrastive(): parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_test_dir', type=str, default=None) + parser.add_argument('--base_out', type=str, default='output/baseline') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') diff --git a/image_ref/main.py b/image_ref/main.py index 6cf959da2921624f146436fcb8a3706f3c84374a..a6b0988ffbd094513f4f645a15124eb842858385 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -174,13 +174,11 @@ def run_duo(args): # load and evaluate best model load_model(model, args.save_path) if args.args.dataset_test_dir is not None : - make_prediction_duo(model, data_test_batch, - 'output/confusion_matrix_contractive_{}_bis_test.png'.format(args.positive_prop), - 'output/confidence_matrix_contractive_{}_bis_test.png'.format(args.positive_prop)) + make_prediction_duo(model, data_test_batch,args.base_out+'_confusion_matrix_test.png', + args.base_out+'confidence_matrix_.png') - make_prediction_duo(model, data_val_batch, - 'output/confusion_matrix_contractive_{}_bis_val.png'.format(args.positive_prop), - 'output/confidence_matrix_contractive_{}_bis_val.png'.format(args.positive_prop)) + make_prediction_duo(model, data_val_batch,args.base_out+'_confusion_matrix_val.png', + args.base_out+'_confusion_matrix_val.png') if args.wandb is not None: wdb.finish()