From 3ccde8cdc08b2a830afb566e3620654cfd0ad79a Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 28 Nov 2024 10:41:22 +0100
Subject: [PATCH] model.eval() and dataviz

---
 data/data_viz.py | 205 +++++++++++++++++++++++++++++++++++++++++++++++
 main.py          |   2 +
 2 files changed, 207 insertions(+)
 create mode 100644 data/data_viz.py

diff --git a/data/data_viz.py b/data/data_viz.py
new file mode 100644
index 0000000..c8d52dd
--- /dev/null
+++ b/data/data_viz.py
@@ -0,0 +1,205 @@
+import scipy as sp
+from sklearn.metrics import r2_score
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+import pandas as pd
+import matplotlib.colors as mcolors
+
+
+def histo_abs_error(dataframe, display=False, save=False, path=None):
+    """Draw a horizontal half-violin plot of the 'abs_error' column.
+
+    The median and the 0.95 quantile are marked and the x axis is limited to
+    [0, 4]. The figure is written to `path` if `save`, shown if `display`.
+    """
+    # violinplot takes a sequence of datasets, hence the one-element list
+    data_to_plot = [dataframe['abs_error']]
+
+    fig, ax = plt.subplots()
+    ax.set_xlabel('abs error')
+    ax.violinplot(data_to_plot, vert=False, side='high', showmedians=True, quantiles=[0.95])
+    ax.set_xlim(0,4)
+
+    # Save first: with a blocking backend the figure is gone after plt.show()
+    if save :
+        fig.savefig(path)
+    if display :
+        plt.show()
+
+
+def random_color_deterministic(df, column):
+    """Add a 'color' column holding one CSS4 color name per `column` value."""
+    palette = list(mcolors.CSS4_COLORS)
+    def rd10(key):
+        # Private RNG seeded with the value: equal values get equal colors and
+        # the global `random` state is left untouched
+        return palette[random.Random(key).randint(0, len(palette) - 1)]
+    df['color']=df[column].map(rd10)
+
+def scatter_rt(dataframe, display=False, save=False, path=None, color = False, col = 'seq'):
+    """Scatter predicted vs. true RT with a linear-regression line and R^2."""
+    fig, ax = plt.subplots()
+    if color :
+        random_color_deterministic(dataframe, col)
+        ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1, color = dataframe['color'])
+    else :
+        ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1)
+    ax.set_xlabel('true RT')
+    ax.set_ylabel('pred RT')
+    x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])])
+    linreg = sp.stats.linregress(dataframe['true rt'], dataframe['rt pred'])
+    ax.annotate("r-squared = {:.3f}".format(r2_score(dataframe['true rt'], dataframe['rt pred'])), (0, 1))
+    ax.plot(x, linreg.intercept + linreg.slope * x, 'r')
+    if save :  # save first: with a blocking backend the figure is gone after show()
+        fig.savefig(path)
+    if display :
+        plt.show()
+
+
+
+
+def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
+    """Draw one half-violin of 'abs_error' per distinct value of 'length'."""
+    # groupby yields only the lengths that occur (no empty violin) and, unlike
+    # Series.where, drops the other rows instead of masking them with NaN.
+    lengths, data_to_plot = [], []
+    for length, errors in dataframe.groupby('length')['abs_error']:
+        lengths.append(length)
+        data_to_plot.append(errors.to_numpy())
+
+    fig, ax = plt.subplots()
+    # Put each violin at its own length; the maximum length is included
+    ax.violinplot(data_to_plot, positions=lengths, vert=True, side='low')
+
+    # Save first: with a blocking backend the figure is gone after plt.show()
+    if save:
+        fig.savefig(path)
+
+    if display:
+        plt.show()
+
+def running_mean(x, N):
+    prefix = np.cumsum(np.insert(x, 0, 0))  # prefix sums with a leading 0
+    return (prefix[N:] - prefix[:-N]) / float(N)  # mean of each N-wide window
+
+def histo_length_by_error(dataframe, bins, display=False, save=False, path=None):
+    """Draw one half-violin of 'length' per absolute-error bin.
+
+    The range [0, max 'abs_error'] is split into `bins` equal-width bins;
+    x ticks are labeled with the bin midpoints (tick 0 is labeled 0).
+    """
+    errors = dataframe['abs_error']
+    inter = np.linspace(0, max(errors), num=bins+1)
+    inter_m = running_mean(inter, 2)
+
+    inter_labels = list(map(lambda x : round(x,2),inter_m))
+    inter_labels.insert(0,0)
+    data_to_plot = []
+    for i in range(bins):
+        # Bins are [low, high), except the last which also includes its upper
+        # edge, so rows sitting on an edge (0, the maximum) are not lost.
+        below_high = errors <= inter[i+1] if i == bins-1 else errors < inter[i+1]
+        a = dataframe.loc[(inter[i] <= errors) & below_high, 'length']
+        # An empty bin is drawn from a [0] placeholder dataset
+        data_to_plot.append(a if len(a)>0 else [0])
+
+    fig, ax = plt.subplots()
+    ax.violinplot(data_to_plot, vert=True, side='high', showmedians=True)
+    ax.set_ylabel('length')
+    ax.set_xticks(range(len(inter)),inter_labels)
+    # Save first: with a blocking backend the figure is gone after plt.show()
+    if save:
+        fig.savefig(path)
+    if display:
+        plt.show()
+
+def compare_error(df1, df2, display=False, save=False, path=None):
+    """Scatter the per-sequence mean signed RT error of df1 against df2.
+
+    Only sequences present in both frames are plotted; inputs are not modified.
+    """
+    err_1 = (df1['rt pred'] - df1['true rt']).groupby(df1['seq']).mean()
+    err_2 = (df2['rt pred'] - df2['true rt']).groupby(df2['seq']).mean()
+    # Align on the sequence itself: a positional concat pairs unrelated
+    # sequences as soon as the two frames do not hold the same set of them.
+    df = pd.concat([err_1.rename('abs err 1'), err_2.rename('abs err 2')], axis=1, join='inner')
+
+    fig, ax = plt.subplots()
+    ax.scatter(df['abs err 1'], df['abs err 2'], s=0.1, alpha=0.05)
+    # Save first: with a blocking backend the figure is gone after plt.show()
+    if save:
+        fig.savefig(path)
+    if display:
+        plt.show()
+
+def select_best_data(df1,df2,threshold):
+    """Keep the rows of df1 whose sequence is well predicted in both frames.
+
+    A sequence is kept when the average of its mean absolute RT errors in
+    df1 and df2 is strictly below `threshold`; sequences missing from either
+    frame are dropped. The inputs are not modified.
+
+    Returns a new frame with the columns 'Sequence' and 'Retention time'.
+    """
+    err_1 = (df1['rt pred'] - df1['true rt']).abs().groupby(df1['seq']).mean()
+    err_2 = (df2['rt pred'] - df2['true rt']).abs().groupby(df2['seq']).mean()
+    # Series addition aligns on the sequence itself rather than on row
+    # position; a sequence absent from one frame yields NaN and is rejected.
+    mean_err = (err_1 + err_2) / 2
+    good_seq = mean_err.index[mean_err < threshold]
+
+    good = df1[df1['seq'].isin(good_seq)]
+    return pd.DataFrame({'Sequence' : good['seq'].tolist(), 'Retention time': good['true rt'].tolist()})
+
+
+
+def add_length(dataframe):
+    """Add a 'length' column: the number of non-zero integers in each 'seq'."""
+    def fonc(a):
+        # 'seq' is parsed as text: brackets stripped, split on ',', cast to int
+        tokens = a.replace('[', '').replace(']', '').split(',')
+        return np.count_nonzero(np.array([int(t) for t in tokens]))
+
+    dataframe['length']=dataframe['seq'].map(fonc)
+
+
+# One prediction CSV per run; all paths are relative to the working directory.
+_RUNS = (
+    'ISA_ISA',
+    'prosit_prosit',
+    'prosit_ISA_noc',
+    'ISA_noc_prosit',
+)
+_FIG_DIR = '../fig/model perf/'
+
+
+def _plot_run(run):
+    """Load '../output/out_<run>.csv' and save its three summary figures.
+
+    Returns the loaded frame, extended with the 'length' and 'abs_error'
+    columns (scatter_rt additionally adds a 'color' column).
+    """
+    df = pd.read_csv('../output/out_{}.csv'.format(run))
+    add_length(df)
+    df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
+    histo_abs_error(df, display=False, save=True, path=_FIG_DIR + 'histo_{}.png'.format(run))
+    scatter_rt(df, display=False, save=True, path=_FIG_DIR + 'RT_pred_{}.png'.format(run), color=True)
+    histo_length_by_error(df, bins=10, display=False, save=True, path=_FIG_DIR + 'histo_length_{}.png'.format(run))
+    return df
+
+
+for _run in _RUNS:
+    df = _plot_run(_run)
+
+
+
+## Compare error variation between runs
+## Prosit column changes affect some peptides more than others (but consistently)
+# df_1 = pd.read_csv('output/out_common_ISA_prosit_eval.csv')
+# df_2 = pd.read_csv('output/out_common_ISA_prosit_eval_2.csv')
+#
+# df = select_best_data(df_1, df_2, 7)
+# df.to_pickle('database/data_prosit_threshold_7.pkl')
+# compare_error(df_1,df_2,save=True,path='fig/custom model res/ISA_prosit_error_variation.png')
+
diff --git a/main.py b/main.py
index 835bb68..1345ab8 100644
--- a/main.py
+++ b/main.py
@@ -43,6 +43,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, metric_rt, wandb=No
 
 
 def eval(model, data_val, epoch, criterion_rt, metric_rt, wandb=None):
+    model.eval()
     losses_rt = 0.
     dist_rt_acc = 0.
     for param in model.parameters():
@@ -136,6 +137,7 @@ def get_n_params(model):
 
 def save_pred(model, data_val, output_path):
     data_frame = pd.DataFrame()
+    model.eval()
     for param in model.parameters():
         param.requires_grad = False
 
-- 
GitLab