Skip to content
Snippets Groups Projects
Commit dfa6f43a authored by Schneider Leo's avatar Schneider Leo
Browse files

r2 scatter viz

parent ac013450
Branches Maries-branch
No related tags found
No related merge requests found
import scipy as sp
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import random import random
...@@ -141,6 +142,10 @@ def scatter_rt(dataframe, display=False, save=False, path=None): ...@@ -141,6 +142,10 @@ def scatter_rt(dataframe, display=False, save=False, path=None):
ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1) ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1)
ax.set_xlabel('true RT') ax.set_xlabel('true RT')
ax.set_ylabel('pred RT') ax.set_ylabel('pred RT')
x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])])
linreg = sp.stats.linregress(dataframe['true rt'], dataframe['rt pred'])
ax.annotate("r-squared = {:.3f}".format(r2_score(dataframe['true rt'], dataframe['rt pred'])), (0, 1))
plt.plot(x, linreg.intercept + linreg.slope * x, 'r')
if display : if display :
plt.show() plt.show()
...@@ -233,5 +238,5 @@ df = pd.read_csv('output/out_prosit_common.csv') ...@@ -233,5 +238,5 @@ df = pd.read_csv('output/out_prosit_common.csv')
add_length(df) add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt']) df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
# histo_abs_error(df, display=False, save=True, path='temp.png') # histo_abs_error(df, display=False, save=True, path='temp.png')
# scatter_rt(df, display=False, save=True, path='temp.png') scatter_rt(df, display=False, save=True, path='RT_pred_prosit_prosit.png')
histo_length_by_error(df, 10, save=True, path='temp.png') # histo_length_by_error(df, 10, save=True, path='temp.png')
\ No newline at end of file \ No newline at end of file
torch~=2.1.2
h5py~=3.10.0
pandas~=2.2.0
numpy~=1.26.2
matplotlib~=3.8.2
wandb~=0.16.2
torchmetrics~=1.3.0.post0
torcheval~=0.0.7
seaborn~=0.13.0
pyopenms~=3.1.0
dlomix~=0.0.6
scikit-learn~=1.4.1.post1
Levenshtein~=0.25.0
keras~=2.15.0
tensorflow~=2.15.0.post1
pillow~=10.3.0
nibabel~=5.2.1
nilearn~=0.10.4
scipy~=1.12.0
loess~=2.1.2
ray~=2.20.0
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment