From 7f5c53cade39b57425294984031671bf9589d370 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 8 Oct 2024 17:01:17 +0200
Subject: [PATCH] selection colomn invariant prosit data

---
 alignement.py | 21 ++++++---------------
 data_viz.py   | 50 +++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/alignement.py b/alignement.py
index f1d6f58..383e6e5 100644
--- a/alignement.py
+++ b/alignement.py
@@ -1,6 +1,8 @@
 import numpy as np
 import pandas as pd
 from loess.loess_1d import loess_1d
+import scipy as sp
+from sklearn.metrics import r2_score
 from sympy.abc import alpha
 
 import dataloader
@@ -79,12 +81,8 @@ def filter_cysteine(df, col):
 def compare_include_df(df, sub_df, save = True, path = 'temp.png'):
     df_value_list = []
     df_sub_value_list=[]
-    i=0
     for r in sub_df.iterrows() :
-        print(i)
-        i+=1
         try :
-
             df_value_list.append(df[df['Sequence']==r[1]['Sequence']]['Retention time'].reset_index(drop=True)[0])
             df_sub_value_list.append(r[1]['Retention time'])
         except:
@@ -92,6 +90,10 @@ def compare_include_df(df, sub_df, save = True, path = 'temp.png'):
 
     fig, ax = plt.subplots()
     ax.scatter(df_sub_value_list, df_value_list)
+    x = np.array([min(df_value_list), max(df_value_list)])
+    linreg = sp.stats.linregress(df_value_list, df_sub_value_list)
+    ax.annotate("r-squared = {:.3f}".format(r2_score(df_value_list, df_sub_value_list)), (0, 1))
+    plt.plot(x, linreg.intercept + linreg.slope * x, 'r')
 
     if save :
         plt.savefig(path)
@@ -194,14 +196,3 @@ df_ISA = pd.read_pickle('database/data_ISA_dual_align.pkl')
 df_diann_aligned = align(df_diann, df_ori)
 
 df_value_list, df_sub_value_list = compare_include_df(df_diann_aligned, df_ISA, True)
-
-import scipy as sp
-from sklearn.metrics import r2_score
-fig, ax = plt.subplots()
-ax.scatter(df_sub_value_list, df_value_list, s=0.1,alpha=0.1)
-x = np.array([min(df_value_list), max(df_value_list)])
-linreg = sp.stats.linregress(df_value_list, df_sub_value_list)
-ax.annotate("r-squared = {:.3f}".format(r2_score(df_value_list, df_sub_value_list)), (0, 1))
-plt.plot(x, linreg.intercept + linreg.slope * x, 'r')
-plt.savefig('scatter_DIANN-ISA_aligned_on_prosit.png')
-plt.clf()
diff --git a/data_viz.py b/data_viz.py
index 87b641d..b00a12b 100644
--- a/data_viz.py
+++ b/data_viz.py
@@ -226,8 +226,8 @@ def histo_length_by_error(dataframe, bins, display=False, save=False, path=None)
         plt.savefig(path)
 
 def compare_error(df1, df2, display=False, save=False, path=None):
-    df1['abs err 1'] = abs(df1['rt pred'] - df1['true rt'])
-    df2['abs err 2'] = abs(df2['rt pred'] - df2['true rt'])
+    df1['abs err 1'] = df1['rt pred'] - df1['true rt']
+    df2['abs err 2'] = df2['rt pred'] - df2['true rt']
     df_group_1 = df1.groupby(['seq'])['abs err 1'].mean().to_frame().reset_index()
     df_group_2 = df2.groupby(['seq'])['abs err 2'].mean().to_frame().reset_index()
     df = pd.concat([df_group_1,df_group_2],axis=1)
@@ -244,6 +244,28 @@ def compare_error(df1, df2, display=False, save=False, path=None):
     if save:
         plt.savefig(path)
 
+def select_best_data(df1,df2,threshold):
+    df1['abs err 1'] = abs(df1['rt pred'] - df1['true rt'])
+    df2['abs err 2'] = abs(df2['rt pred'] - df2['true rt'])
+    df_group_1 = df1.groupby(['seq'])['abs err 1'].mean().to_frame().reset_index()
+    df_group_2 = df2.groupby(['seq'])['abs err 2'].mean().to_frame().reset_index()
+    df = pd.concat([df_group_1, df_group_2], axis=1)
+    df['mean']=(df['abs err 1']+df['abs err 2'])/2
+    df_res = df[df['mean']<threshold]
+    print(df_res.size)
+    df_res = df_res['seq']
+    df_res.columns = ['seq','temp']
+    df_res = df_res['seq']
+
+    good_seq=[]
+    good_rt=[]
+    for r in df1.iterrows() :
+        if r[1]['seq'] in df_res.values :
+            good_rt.append(r[1]['true rt'])
+            good_seq.append(r[1]['seq'])
+    return pd.DataFrame({'Sequence' : good_seq, 'Retention time': good_rt})
+
+
 
 def add_length(dataframe):
     def fonc(a):
@@ -269,12 +291,12 @@ def add_length(dataframe):
 # scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit_eval.png', color=True)
 # histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit_eval.png')
 #
-df = pd.read_csv('output/out_common_prosit_ISA_eval_3.csv')
-add_length(df)
-df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
-histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA_eval_3.png')
-scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA_eval_3.png', color=True)
-histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA_eval_3.png')
+# df = pd.read_csv('output/out_common_prosit_ISA_eval_3.csv')
+# add_length(df)
+# df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
+# histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA_eval_3.png')
+# scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA_eval_3.png', color=True)
+# histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA_eval_3.png')
 
 # df = pd.read_csv('output/out_common_ISA_prosit_eval.csv')
 # add_length(df)
@@ -282,5 +304,15 @@ histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom mo
 # histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_ISA_prosit_eval.png')
 # scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_ISA_prosit_eval.png', color=True, col = 'seq')
 # histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_ISA_prosit_eval.png')
-#
+
+## Compare error variation between run
+## Prosit column changes affect some peptides more than others (but consistently)
+df_1 = pd.read_csv('output/out_common_ISA_prosit_eval.csv')
+df_2 = pd.read_csv('output/out_common_ISA_prosit_eval_2.csv')
+
+df = select_best_data(df_1, df_2, 3)
+df.to_pickle('database/data_prosit_threshold_3.pkl')
+# compare_error(df_1,df_2,save=True,path='fig/custom model res/ISA_prosit_error_variation.png')
+
+
 
-- 
GitLab