From 74e2780680e61b9b15d7d663d9b24c3fbb2a537b Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 1 Apr 2025 15:13:36 +0200
Subject: [PATCH] fix : dataloader batched

---
 image_ref/main.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/image_ref/main.py b/image_ref/main.py
index e2dbdeb..9cd4c7b 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -110,22 +110,26 @@ def run_duo(args):
     plt.savefig('output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
     #load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test, 'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model))
+    make_prediction_duo(model,data_test, 'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
 
 
 def make_prediction_duo(model, data, f_name):
+    n_class = len(data[0][2])
+    confidence_pred_list = [[] for i in range(n_class)]
     y_pred = []
     y_true = []
     # iterate over test data
     for imaer,imana,img_ref, label in data:
         label = label.long()
+        specie = torch.argmax(label)
+
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
         output = model(imaer,imana,img_ref)
-
+        confidence_pred_list[specie].append(output.data.cpu().numpy())
         output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
         y_pred.extend(output)
 
@@ -135,6 +139,10 @@ def make_prediction_duo(model, data, f_name):
 
     # Build confusion matrix
     cf_matrix = confusion_matrix(y_true, y_pred)
+    confidence_matrix = np.zeros((n_class,n_class))
+    for i in range(n_class):
+        confidence_matrix[i]=np.mean(confidence_pred_list[i],axis=0)
+
     df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in range(2)],
                          columns=['True','False'])
     print('Saving Confusion Matrix')
-- 
GitLab