Skip to content
Snippets Groups Projects
Commit b26847c5 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : load model test

parent 5c32d50e
No related branches found
No related tags found
No related merge requests found
...@@ -7,17 +7,17 @@ def load_args_contrastive(): ...@@ -7,17 +7,17 @@ def load_args_contrastive():
parser.add_argument('--epoches', type=int, default=0) parser.add_argument('--epoches', type=int, default=0)
parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--save_inter', type=int, default=50)
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--noise_threshold', type=int, default=500)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--positive_prop', type=int, default=None) parser.add_argument('--positive_prop', type=int, default=None)
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
parser.add_argument('--output', type=str, default='../output/out_contrastive.csv') parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default='../saved_model/baseline_resnet18_contrastive_prop_30.pt') parser.add_argument('--pretrain_path', type=str, default='saved_model/baseline_resnet18_contrastive_prop_30_bis.pt')
args = parser.parse_args() args = parser.parse_args()
return args return args
\ No newline at end of file
...@@ -26,8 +26,8 @@ def compute_class_activation_map(): ...@@ -26,8 +26,8 @@ def compute_class_activation_map():
path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy' path_aer ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_AER.npy'
path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy' path_ana ='../data/processed_data/npy_image/data_test_contrastive/Citrobacter freundii/CITFRE17_ANA.npy'
# path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive # path_ref ='../image_ref/img_ref/Citrobacter freundii.npy' #positive
path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative # path_ref = '../image_ref/img_ref/Enterobacter hormaechei.npy' #negative
# path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative path_ref = '../image_ref/img_ref/Proteus mirabilis.npy' # negative
tensor_aer = npy_loader(path_aer) tensor_aer = npy_loader(path_aer)
tensor_ana = npy_loader(path_ana) tensor_ana = npy_loader(path_ana)
tensor_ref = npy_loader(path_ref) tensor_ref = npy_loader(path_ref)
......
...@@ -81,7 +81,7 @@ def run_duo(args): ...@@ -81,7 +81,7 @@ def run_duo(args):
model.double() model.double()
#load weight #load weight
if args.pretrain_path is not None : if args.pretrain_path is not None :
'Model weight loaded' print('Model weight loaded')
load_model(model,args.pretrain_path) load_model(model,args.pretrain_path)
#move parameters to GPU #move parameters to GPU
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -168,7 +168,6 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -168,7 +168,6 @@ def make_prediction_duo(model, data, f_name, f_name2):
img_ref = img_ref.cuda() img_ref = img_ref.cuda()
label = label.cuda() label = label.cuda()
output = model(imaer,imana,img_ref) output = model(imaer,imana,img_ref)
print(output)
confidence = soft_max(output) confidence = soft_max(output)
confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy()) confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy())
#Mono class output (only most postive paire) #Mono class output (only most postive paire)
...@@ -213,4 +212,5 @@ def load_model(model, path): ...@@ -213,4 +212,5 @@ def load_model(model, path):
if __name__ == '__main__': if __name__ == '__main__':
args = load_args_contrastive() args = load_args_contrastive()
print(args)
run_duo(args) run_duo(args)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment