Skip to content
Snippets Groups Projects
Commit 54489dd3 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : confidence matrix

parent 1e9fc7ac
No related branches found
No related tags found
No related merge requests found
...@@ -146,6 +146,7 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -146,6 +146,7 @@ def make_prediction_duo(model, data, f_name, f_name2):
confidence_pred_list = [[] for i in range(n_class)] confidence_pred_list = [[] for i in range(n_class)]
y_pred = [] y_pred = []
y_true = [] y_true = []
soft_max = nn.Softmax(dim=1)
# iterate over test data # iterate over test data
for imaer,imana,img_ref, label in data: for imaer,imana,img_ref, label in data:
imaer = imaer.transpose(0,1) imaer = imaer.transpose(0,1)
...@@ -163,7 +164,8 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -163,7 +164,8 @@ def make_prediction_duo(model, data, f_name, f_name2):
img_ref = img_ref.cuda() img_ref = img_ref.cuda()
label = label.cuda() label = label.cuda()
output = model(imaer,imana,img_ref) output = model(imaer,imana,img_ref)
confidence_pred_list[specie].append(output[:,0].data.cpu().numpy()) confidence = soft_max(output)
confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy())
#Mono class output (only most postive paire) #Mono class output (only most postive paire)
output = torch.argmax(output[:,0]) output = torch.argmax(output[:,0])
label = torch.argmin(label) label = torch.argmin(label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment