From dfa6f43af6552a318bb5619128870b9f49ea6688 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 23 Sep 2024 11:08:05 +0200
Subject: [PATCH] r2 scatter viz

---
 data_viz.py      | 11 ++++++++---
 requirements.txt | 21 ---------------------
 2 files changed, 8 insertions(+), 24 deletions(-)
 delete mode 100644 requirements.txt

diff --git a/data_viz.py b/data_viz.py
index ba16bc0..83d4d17 100644
--- a/data_viz.py
+++ b/data_viz.py
@@ -1,4 +1,5 @@
-
+import scipy as sp
+from sklearn.metrics import r2_score
 import matplotlib.pyplot as plt
 import numpy as np
 import random
@@ -141,6 +142,10 @@ def scatter_rt(dataframe, display=False, save=False, path=None):
     ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1)
     ax.set_xlabel('true RT')
     ax.set_ylabel('pred RT')
+    x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])])
+    linreg = sp.stats.linregress(dataframe['true rt'], dataframe['rt pred'])
+    ax.annotate("r-squared = {:.3f}".format(r2_score(dataframe['true rt'], dataframe['rt pred'])), (0, 1))
+    plt.plot(x, linreg.intercept + linreg.slope * x, 'r')
     if display :
         plt.show()
 
@@ -233,5 +238,5 @@ df = pd.read_csv('output/out_prosit_common.csv')
 add_length(df)
 df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
 # histo_abs_error(df, display=False, save=True, path='temp.png')
-# scatter_rt(df, display=False, save=True, path='temp.png')
-histo_length_by_error(df, 10, save=True, path='temp.png')
\ No newline at end of file
+scatter_rt(df, display=False, save=True, path='RT_pred_prosit_prosit.png')
+# histo_length_by_error(df, 10, save=True, path='temp.png')
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 7d0c671..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,21 +0,0 @@
-torch~=2.1.2
-h5py~=3.10.0
-pandas~=2.2.0
-numpy~=1.26.2
-matplotlib~=3.8.2
-wandb~=0.16.2
-torchmetrics~=1.3.0.post0
-torcheval~=0.0.7
-seaborn~=0.13.0
-pyopenms~=3.1.0
-dlomix~=0.0.6
-scikit-learn~=1.4.1.post1
-Levenshtein~=0.25.0
-keras~=2.15.0
-tensorflow~=2.15.0.post1
-pillow~=10.3.0
-nibabel~=5.2.1
-nilearn~=0.10.4
-scipy~=1.12.0
-loess~=2.1.2
-ray~=2.20.0
\ No newline at end of file
-- 
GitLab