From c6a40b3a47da6de67a2a65074dbbc18415056ec9 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 22 Apr 2025 15:42:39 +0200
Subject: [PATCH] add : wandb confidence and confusion matrix

---
 image_ref/hyperparameter_res_analysis.py |  6 ++++++
 image_ref/main.py                        | 13 ++++++++++++-
 2 files changed, 18 insertions(+), 1 deletion(-)
 create mode 100644 image_ref/hyperparameter_res_analysis.py

diff --git a/image_ref/hyperparameter_res_analysis.py b/image_ref/hyperparameter_res_analysis.py
new file mode 100644
index 0000000..cdabac5
--- /dev/null
+++ b/image_ref/hyperparameter_res_analysis.py
@@ -0,0 +1,6 @@
+import pandas as pd
+import numpy as np
+
+df = pd.read_csv('../df_results_contrastive.csv')
+
+best_param = df[df['val loss']<0.003]
\ No newline at end of file
diff --git a/image_ref/main.py b/image_ref/main.py
index eaf030d..bab6e98 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -2,7 +2,7 @@ import os
 import wandb as wdb
 import matplotlib.pyplot as plt
 import numpy as np
-
+import PIL
 from config import load_args_contrastive
 from dataset_ref import load_data_duo
 import torch
@@ -180,6 +180,7 @@ def run_duo(args):
         plt.show()
         plt.savefig(args.base_out+'_training_plot.png')
 
+
     # load and evaluate best model
     load_model(model, args.save_path)
     if args.dataset_test_dir is not None :
@@ -190,6 +191,16 @@ def run_duo(args):
                         args.base_out+'_confidence_matrix_val.png')
 
     if args.wandb is not None:
+        if args.dataset_test_dir is not None:
+            wdb.log({
+                     'confidence matrix val' : wdb.Image(args.base_out+'_confidence_matrix_val.png'),
+                     'confidence matrix test' : wdb.Image(args.base_out+'_confidence_matrix_test.png'),
+                     'confusion matrix val' : wdb.Image(args.base_out+'_confusion_matrix_val.png'),
+                     'confusion matrix test' : wdb.Image(args.base_out+'_confusion_matrix_test.png')})
+        else :
+            wdb.log({
+                     'confidence matrix val': wdb.Image(args.base_out + '_confidence_matrix_val.png'),
+                     'confidence matrix test': wdb.Image(args.base_out + '_confidence_matrix_test.png'),})
         wdb.finish()
 
 
-- 
GitLab