From 89df63ba3592ed0577246560e5719a314a69b62b Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 13 May 2025 15:04:16 +0200
Subject: [PATCH] Add Astral dataset extraction and combined fine-tuning

---
 dataset_comparison.py |   2 +-
 dataset_extraction.py |  66 +++++++++++++++++++++++---
 main_fine_tune.py     | 105 ++++++++++++++++++++++++++++++++++--------
 3 files changed, 148 insertions(+), 25 deletions(-)
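
Note for reviewers (free-form commentary below the diffstat is ignored by `git am`):
build_dataset_astral() labels each protein's flyer peptides by intensity tertile,
with class 0 reserved for non-flyers. A minimal sketch of that labelling scheme,
assuming one intensity column per run; the helper name label_by_tertile and its
signature are illustrative and not part of this patch:

    import pandas as pd

    def label_by_tertile(df_group: pd.DataFrame, intensity_col: str) -> dict:
        """Map peptides of one protein to 1/2/3 (weak/medium/strong flyer) by intensity tertile."""
        ordered = df_group.sort_values(by=[intensity_col])
        peptides = ordered['Stripped.Sequence'].to_list()
        values = ordered[intensity_col].to_list()
        weak_cut = values[int(len(values) / 3)]        # boundary of the bottom third
        medium_cut = values[int(2 * len(values) / 3)]  # boundary of the middle third
        return {pep: (1 if v < weak_cut else 2 if v < medium_cut else 3)
                for pep, v in zip(peptides, values)}

In the patch itself, the Astral run column '20250129_ISA_MIX-1_48SPD_001' plays the
role of intensity_col, and the per-peptide labels land in the 'Classes MaxLFQ' column
consumed by create_astral_dataset() and create_combine_dataset().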

diff --git a/dataset_comparison.py b/dataset_comparison.py
index 7e63763..f620d0d 100644
--- a/dataset_comparison.py
+++ b/dataset_comparison.py
@@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict
 
 
 df_list =["Wilhelmlab/detectability-proteometools", "Wilhelmlab/detectability-wang","Wilhelmlab/detectability-sinitcyn"]
-df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
+df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
 df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
 
 for label_type in ['Classes fragment','Classes precursor', 'Classes MaxLFQ'] :
diff --git a/dataset_extraction.py b/dataset_extraction.py
index c7b8e18..ea2d685 100644
--- a/dataset_extraction.py
+++ b/dataset_extraction.py
@@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"}
 """
 
 
-def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
+def build_dataset(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
     df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx')
     df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv')
     #No flyer
@@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
     dico_final={}
     # iterate over each group
     for group_name, df_group in df1_grouped:
-        seq = df_group.sort_values(by=[intensity_col])['Stripped.Sequence'].to_list()
-        value_frag = df_group.sort_values(by=[intensity_col])[intensity_col].to_list()
+        seq = df_group.sort_values(by=['Fragment.Quant.Raw'])['Stripped.Sequence'].to_list()
+        value_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Fragment.Quant.Raw'].to_list()
         value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list()
-        value_prec_frag = df_group.sort_values(by=[intensity_col])['Precursor.Quantity'].to_list()
+        value_prec_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Precursor.Quantity'].to_list()
         value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list()
-        value_maxlfq_frag = df_group.sort_values(by=[intensity_col])['MaxLFQ'].to_list()
+        value_maxlfq_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['MaxLFQ'].to_list()
         threshold_weak_flyer_frag = value_frag[int(len(seq) / 3)]
         threshold_medium_flyer_frag = value_frag[int(2*len(seq) / 3)]
         threshold_weak_flyer_prec = value_prec[int(len(seq) / 3)]
@@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
     df_final=df_final[['Sequences','Proteins','Classes fragment','Classes precursor', 'Classes MaxLFQ']]
     df_final.to_csv(f_name, index=False)
     df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False)
+
+
+
+def build_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
+    df = pd.read_excel('ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx')
+    df_non_flyer = pd.read_excel('ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx')
+    #No flyer
+    df_non_flyer = df_non_flyer[df_non_flyer['Cystein ?']==0]
+    df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['Miscleavage ?'])]
+    df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['MaxLFQ'])]
+    df_non_flyer['Sequences'] = df_non_flyer['Peptide']
+    df_non_flyer['Proteins'] = df_non_flyer['ProteinID']
+    df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates()
+    df_non_flyer['Classes MaxLFQ'] =0
+
+
+    #Flyer
+    df_filtered = df[~(pd.isna(df['Proteotypic ?']))]
+    df_filtered = df_filtered[df_filtered['Coverage']>=coverage_treshold]
+    df_filtered = df_filtered[pd.isna(df_filtered['Miscleavage ? '])]
+    peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
+    filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"]
+    df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]
+
+    df1_grouped = df_filtered.groupby("Protein.Names")
+    dico_final={}
+    # iterate over each group
+    for group_name, df_group in df1_grouped:
+        seq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['Stripped.Sequence'].to_list()
+        value_maxlfq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list()
+        value_maxlfq_frag = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list()
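+        # tertile cut-points over the intensities sorted ascending: bottom third = weak, middle third = medium, top third = strong flyer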
+        threshold_weak_flyer_maxlfq = value_maxlfq[int(len(seq) / 3)]
+        threshold_medium_flyer_maxlfq = value_maxlfq[int(2 * len(seq) / 3)]
+        prot = df_group['Protein.Group'].to_list()[0]
+
+        for i in range(len(seq)):
+
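+            # class codes: 1 = weak flyer, 2 = medium flyer, 3 = strong flyer (0 is reserved for non-flyers)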
+            if value_maxlfq_frag[i] < threshold_weak_flyer_maxlfq :
+                label_maxlfq = 1
+            elif value_maxlfq_frag[i] < threshold_medium_flyer_maxlfq :
+                label_maxlfq = 2
+            else :
+                label_maxlfq = 3
+
+            dico_final[seq[i]] = (prot,label_maxlfq)
+
+    df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Classes MaxLFQ'])
+    df_final['Sequences']=df_final.index
+    df_final = df_final.reset_index()
+    df_final=df_final[['Sequences','Proteins', 'Classes MaxLFQ']]
+    df_final.to_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv', index=False)
+    df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv', index=False)
+
+
 if __name__ == '__main__':
-    build_dataset( coverage_treshold=20, min_peptide=4, f_name='ISA_data/df_finetune_no_miscleavage.csv')
+    build_dataset_astral(coverage_treshold=20, min_peptide=15)
diff --git a/main_fine_tune.py b/main_fine_tune.py
index e295ebe..e145f6d 100644
--- a/main_fine_tune.py
+++ b/main_fine_tune.py
@@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict
 from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report
 
 
-def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=2):
-    df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
+def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
     df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
     df_no_flyer['Classes'] = df_no_flyer[classe_type]
     df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
@@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
     list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
     list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
 
-    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)
-    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
 
     df_train['Proteins']=0
     df_val['Proteins'] = 0
-    df_train.to_csv('temp_fine_tune_df_train.csv', index=False)
-    df_val.to_csv('temp_fine_tune_df_val.csv',index=False)
+    df_train.to_csv('df_preprocessed/df_train_ISA.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False)
+
+def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
+    df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes']=df_flyer['Classes MaxLFQ']
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+    for cl in [1,2,3]:
+        df_class = df_flyer[df_flyer['Classes']==cl]
+        class_count = df_class.shape[0]
+        list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
+        list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
+
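+    # class_count here is from the last loop iteration (cl == 3); the non-flyer subsets are sized from it, as in create_ISA_dataset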
+    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins']=0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False)
+
+def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
+    df_no_flyer_astral = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
+    df_no_flyer_astral['Classes'] = df_no_flyer_astral['Classes MaxLFQ']
+    df_no_flyer_astral = df_no_flyer_astral[['Sequences', 'Classes']]
+    df_flyer_astral['Classes']=df_flyer_astral['Classes MaxLFQ']
+    df_flyer_astral = df_flyer_astral[['Sequences','Classes']]
+
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
+    df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes'] = df_flyer['Classes MaxLFQ']
+    df_flyer = df_flyer[['Sequences', 'Classes']]
+
+
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+    for cl in [1,2,3]:
+        df_class_astral = df_flyer_astral[df_flyer_astral['Classes']==cl]
+        class_count_astral = df_class_astral.shape[0]
+        df_class_ISA = df_flyer[df_flyer['Classes'] == cl]
+        class_count_ISA = df_class_ISA.shape[0]
+
+        list_train_split.append(df_class_astral.iloc[:int(class_count_astral*split[0]),:])
+        list_val_split.append(df_class_astral.iloc[int(class_count_astral * split[0]):, :])
+        list_train_split.append(df_class_ISA.iloc[:int(class_count_ISA * split[0]), :])
+        list_val_split.append(df_class_ISA.iloc[int(class_count_ISA * split[0]):, :])
+
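+    # likewise, class_count_astral and class_count_ISA retain their values from the last loop iteration (cl == 3)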
+    list_train_split.append(df_no_flyer_astral.iloc[:int(class_count_astral * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer_astral.iloc[df_no_flyer_astral.shape[0]-int(class_count_astral * split[1] * frac_no_fly_val):, :])
+    list_train_split.append(df_no_flyer.iloc[:int(class_count_ISA * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0] - int(class_count_ISA * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins']=0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_combined_7.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_combined_multiclass_7.csv',index=False)
 
 def density_plot(prediction_path,prediction_path_2,criteria='base'):
     df = pd.read_csv(prediction_path)
@@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'):
 
 def main():
     total_num_classes = len(CLASSES_LABELS)
-    input_dimension = len(alphabet)
     num_cells = 64
 
-    fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes)
     load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
     fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                           num_clases=total_num_classes)
@@ -94,8 +162,8 @@ def main():
 
     print('Initialising dataset')
     ## Data init
-    fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv',
-                                              val_data_source='temp_fine_tune_df_val.csv',
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv',
+                                              val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv',
                                               data_format='csv',
                                               max_seq_len=max_pep_length,
                                               label_column="Classes",
@@ -114,7 +182,7 @@ def main():
                                                    verbose=1,
                                                    patience=5)
 
-    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
 
     model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
                                                              monitor='val_loss',
@@ -131,13 +199,13 @@ def main():
 
     history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                               validation_data=fine_tune_data.tensor_val_data,
-                                              epochs=1,
+                                              epochs=150,
                                               callbacks=[callback_FT, model_checkpoint_FT])
 
     ## Loading best model weights
 
-    # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
-    model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' # model fine-tuned on combined ISA + Astral data
+    # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
 
     fine_tuned_model.load_weights(model_save_path_FT)
 
@@ -170,16 +238,17 @@ def main():
     report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                     predictions_FT,
                                     test_data_df_FT,
-                                    output_path='./output/report_on_ISA (Base model categorical train, binary val )',
+                                    output_path='./output/report_on_combined_15 (Fine tuned model (combined_15) categorical train, categorical val)',
                                     history=history_fine_tuned,
                                     rank_by_prot=True,
                                     threshold=None,
-                                    name_of_dataset='ISA val dataset (binary balanced)',
-                                    name_of_model='Base model (ISA)')
+                                    name_of_dataset='combined_15 val dataset (categorical balanced)',
+                                    name_of_model='Fine tuned model (combined_15)')
 
     report_FT.generate_report()
 
 if __name__ == '__main__':
-    create_ISA_dataset()
+    # create_astral_dataset()
+    # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
     main()
     # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
-- 
GitLab