Skip to content
Snippets Groups Projects
Commit 3ccde8cd authored by Schneider Leo's avatar Schneider Leo
Browse files

model.eval() and dataviz

parent e54a7822
No related branches found
No related tags found
No related merge requests found
import scipy as sp
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import matplotlib.colors as mcolors
def histo_abs_error(dataframe, display=False, save=False, path=None, xlim=(0, 4)):
    """Draw a horizontal half-violin plot of the 'abs_error' column.

    The median and the 95% quantile are marked on the violin.

    Args:
        dataframe: Table with an 'abs_error' column.
        display: If True, show the figure with ``plt.show()``.
        save: If True, write the figure to ``path``.
        path: Output file of the figure; required when ``save`` is True.
        xlim: ``(low, high)`` range of the error axis; values outside it are
            not visible on the plot.

    Raises:
        ValueError: If ``save`` is True but no ``path`` is given.
    """
    if save and path is None:
        raise ValueError("path must be provided when save=True")

    fig, ax = plt.subplots()
    ax.set_xlabel('abs error')
    # violinplot expects a collection of data sets; here there is only one.
    ax.violinplot([dataframe['abs_error']], vert=False, side='high',
                  showmedians=True, quantiles=[0.95])
    ax.set_xlim(*xlim)

    # Save before showing: with a blocking GUI backend the figure is torn down
    # once the window opened by plt.show() is closed, so saving afterwards
    # would write an empty canvas.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
def random_color_deterministic(df, column):
    """Add a 'color' column assigning a CSS4 color name to each value of ``column``.

    The color is drawn from a random generator seeded with the value itself,
    so equal values always receive the same color. Distinct values may still
    end up with the same color. ``df`` is modified in place.

    Args:
        df: Table to annotate; receives a new (or overwritten) 'color' column.
        column: Name of the column whose values select the color.
    """
    # Built once instead of once per row.
    palette = list(mcolors.CSS4_COLORS)

    def pick_color(value):
        # A private generator keeps the seeding local: the module-level
        # `random` state used by the rest of the program is left untouched.
        rng = random.Random(value)
        return palette[rng.randint(0, len(palette) - 1)]

    df['color'] = df[column].map(pick_color)
def scatter_rt(dataframe, display=False, save=False, path=None, color=False, col='seq'):
    """Scatter predicted against true retention time with a regression line.

    The plot is annotated with the r-squared score of the predictions
    (``r2_score`` of 'rt pred' against 'true rt') and overlaid with the
    least-squares line fitted through the points, drawn in red.

    Args:
        dataframe: Table with 'true rt' and 'rt pred' columns.
        display: If True, show the figure with ``plt.show()``.
        save: If True, write the figure to ``path``.
        path: Output file of the figure; required when ``save`` is True.
        color: If True, color each point according to its value in ``col``.
            This adds a 'color' column to ``dataframe``.
        col: Column used to pick the point colors when ``color`` is True.

    Raises:
        ValueError: If ``save`` is True but no ``path`` is given.
    """
    # A bare `import scipy` does not guarantee that the `stats` submodule is
    # loaded on older SciPy releases; importing it explicitly always works.
    from scipy import stats

    if save and path is None:
        raise ValueError("path must be provided when save=True")

    true_rt = dataframe['true rt']
    pred_rt = dataframe['rt pred']

    fig, ax = plt.subplots()
    scatter_options = {'s': .1}
    if color:
        random_color_deterministic(dataframe, col)
        scatter_options['color'] = dataframe['color']
    ax.scatter(true_rt, pred_rt, **scatter_options)
    ax.set_xlabel('true RT')
    ax.set_ylabel('pred RT')

    # The regression line only needs its two end points.
    x = np.array([true_rt.min(), true_rt.max()])
    linreg = stats.linregress(true_rt, pred_rt)
    ax.annotate("r-squared = {:.3f}".format(r2_score(true_rt, pred_rt)), (0, 1))
    ax.plot(x, linreg.intercept + linreg.slope * x, 'r')

    # Save before showing: with a blocking GUI backend the figure is torn down
    # once the window opened by plt.show() is closed, so saving afterwards
    # would write an empty canvas.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
    """Draw one half-violin of 'abs_error' per distinct value of 'length'.

    Every length present in the data gets a violin placed at that length on
    the x axis; lengths without any row are simply absent.

    Args:
        dataframe: Table with 'length' and 'abs_error' columns.
        display: If True, show the figure with ``plt.show()``.
        save: If True, write the figure to ``path``.
        path: Output file of the figure; required when ``save`` is True.

    Raises:
        ValueError: If ``dataframe`` has no rows, or if ``save`` is True but
            no ``path`` is given.
    """
    if save and path is None:
        raise ValueError("path must be provided when save=True")

    # groupby yields only the lengths that occur (including the largest one)
    # and only the rows that belong to each of them.
    lengths = []
    data_to_plot = []
    for length, errors in dataframe.groupby('length')['abs_error']:
        lengths.append(length)
        data_to_plot.append(errors.to_numpy())
    if not data_to_plot:
        raise ValueError("dataframe has no rows to plot")

    fig, ax = plt.subplots()
    ax.violinplot(data_to_plot, positions=lengths, vert=True, side='low')
    ax.set_xlabel('length')
    ax.set_ylabel('abs error')

    # Save before showing: with a blocking GUI backend the figure is torn down
    # once the window opened by plt.show() is closed, so saving afterwards
    # would write an empty canvas.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
def running_mean(x, N):
    """Return the simple moving average of ``x`` over windows of ``N`` values.

    Args:
        x: Sequence or array of numbers.
        N: Window size, a positive integer.

    Returns:
        Array holding the mean of every run of ``N`` consecutive values, i.e.
        ``len(x) - N + 1`` entries; empty when ``N`` exceeds ``len(x)``.

    Raises:
        ValueError: If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be a positive integer")
    # Prefix sums with a leading 0, so that the sum of x[i:i + N] is
    # cumsum[i + N] - cumsum[i].
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[N:] - cumsum[:-N]) / float(N)
def histo_length_by_error(dataframe, bins, display=False, save=False, path=None):
    """Draw the distribution of 'length' inside equal-width 'abs_error' bins.

    The range from 0 to the largest 'abs_error' is cut into ``bins`` intervals
    and one half-violin of 'length' (with its median) is drawn per interval.
    Each violin's tick is labelled with the midpoint of its interval, rounded
    to two decimals.

    Args:
        dataframe: Table with 'abs_error' and 'length' columns.
        bins: Number of error intervals, at least 1.
        display: If True, show the figure with ``plt.show()``.
        save: If True, write the figure to ``path``.
        path: Output file of the figure; required when ``save`` is True.

    Raises:
        ValueError: If ``bins`` is smaller than 1, or if ``save`` is True but
            no ``path`` is given.
    """
    if bins < 1:
        raise ValueError("bins must be at least 1")
    if save and path is None:
        raise ValueError("path must be provided when save=True")

    errors = dataframe['abs_error']
    edges = np.linspace(0, errors.max(), num=bins + 1)
    # Violins sit at positions 1..bins; tick 0 is an extra tick labelled 0 and
    # tick i carries the midpoint of the i-th interval.
    tick_labels = [0] + [round(mid, 2) for mid in running_mean(edges, 2)]

    data_to_plot = []
    last = bins - 1
    for i in range(bins):
        lower, upper = edges[i], edges[i + 1]
        # Intervals are [lower, upper), except the last one which also takes
        # its upper edge, so that no row is lost (in particular not the rows
        # holding the largest error).
        if i == last:
            in_bin = (errors >= lower) & (errors <= upper)
        else:
            in_bin = (errors >= lower) & (errors < upper)
        lengths = dataframe.loc[in_bin, 'length']
        # An empty interval gets a placeholder so that the following violins
        # keep their position.
        data_to_plot.append(lengths if len(lengths) > 0 else [0])

    fig, ax = plt.subplots()
    ax.violinplot(data_to_plot, vert=True, side='high', showmedians=True)
    ax.set_ylabel('length')
    ax.set_xticks(range(len(edges)), tick_labels)

    # Save before showing: with a blocking GUI backend the figure is torn down
    # once the window opened by plt.show() is closed, so saving afterwards
    # would write an empty canvas.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
def compare_error(df1, df2, display=False, save=False, path=None):
    """Scatter the per-sequence mean error of one run against another.

    For each table the error 'rt pred' - 'true rt' is averaged per 'seq'.
    Only sequences present in both tables are plotted, each as one point
    (x: mean error in ``df1``, y: mean error in ``df2``).

    Side effect: the per-row errors are stored in ``df1['abs err 1']`` and
    ``df2['abs err 2']``. Despite the column names these errors are signed.

    Args:
        df1: First table, with 'seq', 'rt pred' and 'true rt' columns.
        df2: Second table, with the same columns.
        display: If True, show the figure with ``plt.show()``.
        save: If True, write the figure to ``path``.
        path: Output file of the figure; required when ``save`` is True.

    Raises:
        ValueError: If ``save`` is True but no ``path`` is given.
    """
    if save and path is None:
        raise ValueError("path must be provided when save=True")

    df1['abs err 1'] = df1['rt pred'] - df1['true rt']
    df2['abs err 2'] = df2['rt pred'] - df2['true rt']
    mean_err_1 = df1.groupby('seq')['abs err 1'].mean()
    mean_err_2 = df2.groupby('seq')['abs err 2'].mean()
    # Both series are indexed by 'seq'; the inner join pairs the two means of
    # the same sequence instead of pairing rows by position.
    paired = pd.concat([mean_err_1, mean_err_2], axis=1, join='inner')

    fig, ax = plt.subplots()
    ax.scatter(paired['abs err 1'], paired['abs err 2'], s=0.1, alpha=0.05)

    # Save before showing: with a blocking GUI backend the figure is torn down
    # once the window opened by plt.show() is closed, so saving afterwards
    # would write an empty canvas.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
def select_best_data(df1, df2, threshold):
    """Keep the rows of ``df1`` whose sequence is well predicted in both runs.

    The absolute error ``|'rt pred' - 'true rt'|`` is averaged per 'seq' in
    each table; a sequence is kept when the mean of its two averages is
    strictly below ``threshold``. Sequences missing from either table are
    never kept.

    Side effect: the per-row absolute errors are stored in
    ``df1['abs err 1']`` and ``df2['abs err 2']``.

    Args:
        df1: First table, with 'seq', 'rt pred' and 'true rt' columns.
        df2: Second table, with the same columns.
        threshold: Exclusive upper bound on the mean absolute error.

    Returns:
        DataFrame with columns 'Sequence' and 'Retention time' holding the
        'seq' and 'true rt' of every kept row of ``df1``, in ``df1`` order.
    """
    df1['abs err 1'] = (df1['rt pred'] - df1['true rt']).abs()
    df2['abs err 2'] = (df2['rt pred'] - df2['true rt']).abs()
    mean_err_1 = df1.groupby('seq')['abs err 1'].mean()
    mean_err_2 = df2.groupby('seq')['abs err 2'].mean()
    # Both series are indexed by 'seq'; the inner join pairs the two means of
    # the same sequence instead of pairing rows by position.
    paired = pd.concat([mean_err_1, mean_err_2], axis=1, join='inner')
    mean_error = (paired['abs err 1'] + paired['abs err 2']) / 2
    good_sequences = mean_error.index[mean_error < threshold]

    # One vectorised membership test instead of a scan of the kept sequences
    # for every row.
    kept = df1.loc[df1['seq'].isin(good_sequences), ['seq', 'true rt']]
    return pd.DataFrame({
        'Sequence': kept['seq'].to_numpy(),
        'Retention time': kept['true rt'].to_numpy(),
    })
def add_length(dataframe):
    """Add a 'length' column counting the non-zero entries of each 'seq'.

    Every 'seq' value must be the text of a flat, comma-separated integer
    list such as ``"[3, 7, 0, 0]"``; an empty list (``"[]"``) has length 0.
    ``dataframe`` is modified in place.
    NOTE(review): the zeros look like padding of the encoded sequence — confirm.

    Args:
        dataframe: Table with a 'seq' column of list-formatted strings.

    Raises:
        ValueError: If a non-empty 'seq' contains a token that is not an
            integer.
    """
    def count_nonzero(text):
        body = text.replace('[', '').replace(']', '')
        # Splitting an empty body would yield [''], which int() rejects.
        if not body.strip():
            return 0
        return sum(1 for token in body.split(',') if int(token) != 0)

    dataframe['length'] = dataframe['seq'].map(count_nonzero)
# Produce the three performance figures for every prediction file
# '../output/out_<run>.csv'; each file must provide the 'seq', 'rt pred' and
# 'true rt' columns.
for run_tag in ('ISA_ISA', 'prosit_prosit', 'prosit_ISA_noc', 'ISA_noc_prosit'):
    df = pd.read_csv(f'../output/out_{run_tag}.csv')
    add_length(df)
    df['abs_error'] = np.abs(df['rt pred'] - df['true rt'])
    histo_abs_error(df, display=False, save=True,
                    path=f'../fig/model perf/histo_{run_tag}.png')
    scatter_rt(df, display=False, save=True,
               path=f'../fig/model perf/RT_pred_{run_tag}.png', color=True)
    histo_length_by_error(df, bins=10, display=False, save=True,
                          path=f'../fig/model perf/histo_length_{run_tag}.png')
    # pyplot keeps every figure alive until it is closed explicitly; release
    # the ones of this run now that they are on disk.
    plt.close('all')
## Compare error variation between runs
## Prosit column changes affect some peptides more than others (but consistently)
# df_1 = pd.read_csv('output/out_common_ISA_prosit_eval.csv')
# df_2 = pd.read_csv('output/out_common_ISA_prosit_eval_2.csv')
#
# df = select_best_data(df_1, df_2, 7)
# df.to_pickle('database/data_prosit_threshold_7.pkl')
# compare_error(df_1,df_2,save=True,path='fig/custom model res/ISA_prosit_error_variation.png')
......@@ -43,6 +43,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, metric_rt, wandb=No
def eval(model, data_val, epoch, criterion_rt, metric_rt, wandb=None):
model.eval()
losses_rt = 0.
dist_rt_acc = 0.
for param in model.parameters():
......@@ -136,6 +137,7 @@ def get_n_params(model):
def save_pred(model, data_val, output_path):
data_frame = pd.DataFrame()
model.eval()
for param in model.parameters():
param.requires_grad = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment