From dfa6f43af6552a318bb5619128870b9f49ea6688 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 23 Sep 2024 11:08:05 +0200 Subject: [PATCH] r2 scatter viz --- data_viz.py | 11 ++++++++--- requirements.txt | 21 --------------------- 2 files changed, 8 insertions(+), 24 deletions(-) delete mode 100644 requirements.txt diff --git a/data_viz.py b/data_viz.py index ba16bc0..83d4d17 100644 --- a/data_viz.py +++ b/data_viz.py @@ -1,4 +1,5 @@ - +import scipy as sp +from sklearn.metrics import r2_score import matplotlib.pyplot as plt import numpy as np import random @@ -141,6 +142,10 @@ def scatter_rt(dataframe, display=False, save=False, path=None): ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1) ax.set_xlabel('true RT') ax.set_ylabel('pred RT') + x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])]) + linreg = sp.stats.linregress(dataframe['true rt'], dataframe['rt pred']) + ax.annotate("r-squared = {:.3f}".format(r2_score(dataframe['true rt'], dataframe['rt pred'])), (0, 1)) + plt.plot(x, linreg.intercept + linreg.slope * x, 'r') if display : plt.show() @@ -233,5 +238,5 @@ df = pd.read_csv('output/out_prosit_common.csv') add_length(df) df['abs_error'] = np.abs(df['rt pred']-df['true rt']) # histo_abs_error(df, display=False, save=True, path='temp.png') -# scatter_rt(df, display=False, save=True, path='temp.png') -histo_length_by_error(df, 10, save=True, path='temp.png') \ No newline at end of file +scatter_rt(df, display=False, save=True, path='RT_pred_prosit_prosit.png') +# histo_length_by_error(df, 10, save=True, path='temp.png') \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7d0c671..0000000 --- a/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -torch~=2.1.2 -h5py~=3.10.0 -pandas~=2.2.0 -numpy~=1.26.2 -matplotlib~=3.8.2 -wandb~=0.16.2 -torchmetrics~=1.3.0.post0 -torcheval~=0.0.7 -seaborn~=0.13.0 -pyopenms~=3.1.0 -dlomix~=0.0.6 -scikit-learn~=1.4.1.post1 -Levenshtein~=0.25.0 -keras~=2.15.0 -tensorflow~=2.15.0.post1 -pillow~=10.3.0 -nibabel~=5.2.1 -nilearn~=0.10.4 -scipy~=1.12.0 -loess~=2.1.2 -ray~=2.20.0 \ No newline at end of file -- GitLab