diff --git a/image_ref/main.py b/image_ref/main.py index 62321f2899ab255ad6a7e648998888f03ac2c893..1d73a0d068424f7fea748d2b1d4144a2cd651d86 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -146,6 +146,7 @@ def make_prediction_duo(model, data, f_name, f_name2): confidence_pred_list = [[] for i in range(n_class)] y_pred = [] y_true = [] + soft_max = nn.Softmax(dim=1) # iterate over test data for imaer,imana,img_ref, label in data: imaer = imaer.transpose(0,1) @@ -163,7 +164,8 @@ def make_prediction_duo(model, data, f_name, f_name2): img_ref = img_ref.cuda() label = label.cuda() output = model(imaer,imana,img_ref) - confidence_pred_list[specie].append(output[:,0].data.cpu().numpy()) + confidence = soft_max(output) + confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy()) #Mono class output (only most postive paire) output = torch.argmax(output[:,0]) label = torch.argmin(label)