From 3ccde8cdc08b2a830afb566e3620654cfd0ad79a Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 28 Nov 2024 10:41:22 +0100
Subject: [PATCH] model.eval() and dataviz

---
Review notes (ignored by git am, not part of the commit message):
- Line breaks and indentation were restored; blank context lines in the
  main.py hunks are inferred from the hunk line counts.
- data/data_viz.py was amended during review (figures saved before being
  shown and closed afterwards, per-sequence errors merged on 'seq', bin
  and length ranges made inclusive, debug 'temp.png' output removed), so
  the commit id above and the blob id on its "index" line are stale.

 data/data_viz.py | 192 ++++++++++++++++++++++++++++++++++++++++++++++++
 main.py          |   2 +
 2 files changed, 194 insertions(+)
 create mode 100644 data/data_viz.py

diff --git a/data/data_viz.py b/data/data_viz.py
new file mode 100644
index 0000000..c8d52dd
--- /dev/null
+++ b/data/data_viz.py
@@ -0,0 +1,192 @@
+import scipy as sp
+from sklearn.metrics import r2_score
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+import pandas as pd
+import matplotlib.colors as mcolors
+
+
+def _finalize(fig, display, save, path):
+    """Write *fig* to *path* when *save* is set, then show it or close it.
+
+    Saving goes through the figure object and happens before ``plt.show()``,
+    so the written image does not depend on which figure pyplot considers
+    current. Figures that are not displayed are closed so that repeated
+    calls do not accumulate open figures.
+    """
+    if save:
+        if path is None:
+            raise ValueError("save=True requires a path")
+        fig.savefig(path)
+    if display:
+        plt.show()
+    else:
+        plt.close(fig)
+
+
+def histo_abs_error(dataframe, display=False, save=False, path=None):
+    """Plot ``dataframe['abs_error']`` as a horizontal half violin.
+
+    The median and the 0.95 quantile are marked and the x axis is fixed to
+    [0, 4]. *display* shows the figure, *save* writes it to *path*.
+    """
+    fig, ax = plt.subplots()
+    ax.set_xlabel('abs error')
+    ax.violinplot([dataframe['abs_error']], vert=False, side='high',
+                  showmedians=True, quantiles=[0.95])
+    ax.set_xlim(0, 4)
+    _finalize(fig, display, save, path)
+
+
+def random_color_deterministic(df, column):
+    """Add a 'color' column giving each value of *column* a reproducible CSS4 colour."""
+    palette = list(mcolors.CSS4_COLORS)
+
+    def pick(value):
+        # A private generator seeded with the value is reproducible and, unlike
+        # random.seed(), leaves the global random state untouched.
+        return palette[random.Random(value).randint(0, len(palette) - 1)]
+
+    df['color'] = df[column].map(pick)
+
+
+def scatter_rt(dataframe, display=False, save=False, path=None, color=False, col='seq'):
+    """Scatter 'rt pred' against 'true rt' with the fitted line and R-squared.
+
+    With *color* set, points are coloured per value of column *col*, which
+    adds a 'color' column to *dataframe*.
+    """
+    true_rt, pred_rt = dataframe['true rt'], dataframe['rt pred']
+    fig, ax = plt.subplots()
+    if color:
+        random_color_deterministic(dataframe, col)
+        ax.scatter(true_rt, pred_rt, s=.1, color=dataframe['color'])
+    else:
+        ax.scatter(true_rt, pred_rt, s=.1)
+    ax.set_xlabel('true RT')
+    ax.set_ylabel('pred RT')
+    x = np.array([true_rt.min(), true_rt.max()])
+    linreg = sp.stats.linregress(true_rt, pred_rt)
+    # Anchor the label to the axes so it stays visible whatever the RT range.
+    ax.annotate("r-squared = {:.3f}".format(r2_score(true_rt, pred_rt)),
+                (0.02, 0.95), xycoords='axes fraction')
+    ax.plot(x, linreg.intercept + linreg.slope * x, 'r')
+    _finalize(fig, display, save, path)
+
+
+def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
+    """Draw one half violin of 'abs_error' per 'length', placed at that length."""
+    # Grouping (rather than masking with Series.where over a range) keeps NaN
+    # fillers out of the violins and includes the largest length.
+    grouped = dataframe.groupby('length')['abs_error']
+    by_length = {length: errors.to_numpy() for length, errors in grouped}
+    fig, ax = plt.subplots()
+    if by_length:
+        ax.violinplot(list(by_length.values()), positions=list(by_length),
+                      vert=True, side='low')
+    _finalize(fig, display, save, path)
+
+
+def running_mean(x, N):
+    """Return the mean of every window of *N* consecutive values of *x*."""
+    cumsum = np.cumsum(np.insert(x, 0, 0))
+    return (cumsum[N:] - cumsum[:-N]) / float(N)
+
+
+def histo_length_by_error(dataframe, bins, display=False, save=False, path=None):
+    """Draw one half violin of 'length' per 'abs_error' bin.
+
+    The range [0, max error] is cut into *bins* equal intervals; tick labels
+    are the interval midpoints rounded to two decimals.
+    """
+    errors = dataframe['abs_error']
+    inter = np.linspace(0, max(errors), num=bins+1)
+    inter_labels = [0] + [round(m, 2) for m in running_mean(inter, 2)]
+    data_to_plot = []
+    for i in range(bins):
+        low, high = inter[i], inter[i+1]
+        # Bins are (low, high], closed on the left for the first one, so an
+        # error lying exactly on an edge (0 and the maximum included) counts once.
+        above = errors >= low if i == 0 else errors > low
+        lengths = dataframe.loc[above & (errors <= high), 'length']
+        # Empty bins keep a lone 0 placeholder so every bin keeps its position.
+        data_to_plot.append(lengths if len(lengths) > 0 else [0])
+
+    fig, ax = plt.subplots()
+    ax.violinplot(data_to_plot, vert=True, side='high', showmedians=True)
+    ax.set_ylabel('length')
+    ax.set_xticks(range(len(inter)), inter_labels)
+    _finalize(fig, display, save, path)
+
+
+def _mean_error_by_seq(df, name, absolute):
+    """Return one row per 'seq' with the mean of 'rt pred' - 'true rt' in column *name*.
+
+    With *absolute* set, the error is taken in absolute value before averaging.
+    """
+    error = df['rt pred'] - df['true rt']
+    if absolute:
+        error = error.abs()
+    return error.groupby(df['seq']).mean().rename(name).reset_index()
+
+
+def compare_error(df1, df2, display=False, save=False, path=None):
+    """Scatter the per-sequence mean signed error of *df1* against that of *df2*.
+
+    The two tables are joined on 'seq', so each point compares the same
+    sequence and sequences missing from either input are left out.
+    """
+    df = _mean_error_by_seq(df1, 'abs err 1', absolute=False).merge(
+        _mean_error_by_seq(df2, 'abs err 2', absolute=False), on='seq')
+    fig, ax = plt.subplots()
+    ax.scatter(df['abs err 1'], df['abs err 2'], s=0.1, alpha=0.05)
+    _finalize(fig, display, save, path)
+
+
+def select_best_data(df1, df2, threshold):
+    """Return the rows of *df1* whose sequence has a low error averaged over both inputs.
+
+    A sequence is kept when the average of its mean absolute errors in *df1*
+    and *df2* is below *threshold*. The result has the columns 'Sequence'
+    and 'Retention time' (the 'true rt' of *df1*).
+    """
+    df = _mean_error_by_seq(df1, 'abs err 1', absolute=True).merge(
+        _mean_error_by_seq(df2, 'abs err 2', absolute=True), on='seq')
+    good_seq = df.loc[(df['abs err 1'] + df['abs err 2']) / 2 < threshold, 'seq']
+    kept = df1[df1['seq'].isin(good_seq)]
+    return pd.DataFrame({'Sequence': kept['seq'].to_numpy(),
+                         'Retention time': kept['true rt'].to_numpy()})
+
+
+def add_length(dataframe):
+    """Add a 'length' column counting the non-zero entries of each 'seq' string.
+
+    Presumably 'seq' holds text such as '[12, 5, 0, 0]' with 0 used as
+    padding - confirm against the code that writes the CSV files.
+    """
+    def fonc(a):
+        tokens = a.replace('[', '').replace(']', '').split(',')
+        return sum(int(token) != 0 for token in tokens if token.strip())
+    dataframe['length'] = dataframe['seq'].map(fonc)
+
+
+# One set of figures per prediction file in ../output.
+for run in ('ISA_ISA', 'prosit_prosit', 'prosit_ISA_noc', 'ISA_noc_prosit'):
+    df = pd.read_csv(f'../output/out_{run}.csv')
+    add_length(df)
+    df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
+    histo_abs_error(df, display=False, save=True, path=f'../fig/model perf/histo_{run}.png')
+    scatter_rt(df, display=False, save=True, path=f'../fig/model perf/RT_pred_{run}.png', color=True)
+    histo_length_by_error(df, bins=10, display=False, save=True,
+                          path=f'../fig/model perf/histo_length_{run}.png')
+
+
+## Compare error variation between runs
+## Prosit column changes affect some peptides more than others (but consistently)
+# df_1 = pd.read_csv('output/out_common_ISA_prosit_eval.csv')
+# df_2 = pd.read_csv('output/out_common_ISA_prosit_eval_2.csv')
+#
+# df = select_best_data(df_1, df_2, 7)
+# df.to_pickle('database/data_prosit_threshold_7.pkl')
+# compare_error(df_1,df_2,save=True,path='fig/custom model res/ISA_prosit_error_variation.png')
diff --git a/main.py b/main.py
index 835bb68..1345ab8 100644
--- a/main.py
+++ b/main.py
@@ -43,6 +43,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, metric_rt, wandb=No
 
 
 def eval(model, data_val, epoch, criterion_rt, metric_rt, wandb=None):
+    model.eval()
     losses_rt = 0.
     dist_rt_acc = 0.
     for param in model.parameters():
@@ -136,6 +137,7 @@ def get_n_params(model):
 
 def save_pred(model, data_val, output_path):
     data_frame = pd.DataFrame()
+    model.eval()
     for param in model.parameters():
         param.requires_grad = False
 
-- 
GitLab