Skip to content
Snippets Groups Projects
Commit 21015384 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : print confidence matrix construction

parent 92781b5d
No related branches found
No related tags found
No related merge requests found
...@@ -228,6 +228,7 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -228,6 +228,7 @@ def make_prediction_duo(model, data, f_name, f_name2):
label = label.long() label = label.long()
specie = torch.argmin(label) specie = torch.argmin(label)
if torch.cuda.is_available(): if torch.cuda.is_available():
imaer = imaer.cuda() imaer = imaer.cuda()
imana = imana.cuda() imana = imana.cuda()
...@@ -235,6 +236,8 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -235,6 +236,8 @@ def make_prediction_duo(model, data, f_name, f_name2):
label = label.cuda() label = label.cuda()
output = model(imaer, imana, img_ref) output = model(imaer, imana, img_ref)
confidence = soft_max(output) confidence = soft_max(output)
print(label)
print(confidence)
confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy()) confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy())
# Mono class output (only most postive paire) # Mono class output (only most postive paire)
output = torch.argmax(output[:, 0]) output = torch.argmax(output[:, 0])
...@@ -248,8 +251,10 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -248,8 +251,10 @@ def make_prediction_duo(model, data, f_name, f_name2):
cf_matrix = confusion_matrix(y_true, y_pred) cf_matrix = confusion_matrix(y_true, y_pred)
confidence_matrix = np.zeros((n_class, n_class)) confidence_matrix = np.zeros((n_class, n_class))
for i in range(n_class): for i in range(n_class):
print('species ',classes[i],' nb sample test : ',len(confidence_pred_list[i]))
confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0) confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
columns=[i for i in classes]) columns=[i for i in classes])
print('Saving Confusion Matrix') print('Saving Confusion Matrix')
...@@ -267,6 +272,8 @@ def make_prediction_duo(model, data, f_name, f_name2): ...@@ -267,6 +272,8 @@ def make_prediction_duo(model, data, f_name, f_name2):
plt.savefig(f_name2) plt.savefig(f_name2)
def save_model(model, path): def save_model(model, path):
print('Model saved') print('Model saved')
torch.save(model.state_dict(), path) torch.save(model.state_dict(), path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment