From 89df63ba3592ed0577246560e5719a314a69b62b Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 13 May 2025 15:04:16 +0200 Subject: [PATCH] astral dataset --- dataset_comparison.py | 2 +- dataset_extraction.py | 66 +++++++++++++++++++++++--- main_fine_tune.py | 105 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 148 insertions(+), 25 deletions(-) diff --git a/dataset_comparison.py b/dataset_comparison.py index 7e63763..f620d0d 100644 --- a/dataset_comparison.py +++ b/dataset_comparison.py @@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict df_list =["Wilhelmlab/detectability-proteometools", "Wilhelmlab/detectability-wang","Wilhelmlab/detectability-sinitcyn"] -df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv') +df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv') df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv') for label_type in ['Classes fragment','Classes precursor', 'Classes MaxLFQ'] : diff --git a/dataset_extraction.py b/dataset_extraction.py index c7b8e18..ea2d685 100644 --- a/dataset_extraction.py +++ b/dataset_extraction.py @@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"} """ -def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'): +def build_dataset(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'): df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx') df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv') #No flyer @@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m dico_final={} # iterate over each group for group_name, df_group in df1_grouped: - seq = df_group.sort_values(by=[intensity_col])['Stripped.Sequence'].to_list() - value_frag = df_group.sort_values(by=[intensity_col])[intensity_col].to_list() + seq = df_group.sort_values(by=['Fragment.Quant.Raw'])['Stripped.Sequence'].to_list() + value_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Fragment.Quant.Raw'].to_list() value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list() - value_prec_frag = df_group.sort_values(by=[intensity_col])['Precursor.Quantity'].to_list() + value_prec_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Precursor.Quantity'].to_list() value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list() - value_maxlfq_frag = df_group.sort_values(by=[intensity_col])['MaxLFQ'].to_list() + value_maxlfq_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['MaxLFQ'].to_list() threshold_weak_flyer_frag = value_frag[int(len(seq) / 3)] threshold_medium_flyer_frag = value_frag[int(2*len(seq) / 3)] threshold_weak_flyer_prec = value_prec[int(len(seq) / 3)] @@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m df_final=df_final[['Sequences','Proteins','Classes fragment','Classes precursor', 'Classes MaxLFQ']] df_final.to_csv(f_name, index=False) df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False) + + + +def build_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'): + df = pd.read_excel('ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx') + df_non_flyer = pd.read_excel('ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx') + #No flyer + df_non_flyer = df_non_flyer[df_non_flyer['Cystein ?']==0] + df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['Miscleavage ?'])] + df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['MaxLFQ'])] + df_non_flyer['Sequences'] = df_non_flyer['Peptide'] + df_non_flyer['Proteins'] = df_non_flyer['ProteinID'] + df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates() + df_non_flyer['Classes MaxLFQ'] =0 + + + #Flyer + df_filtered = df[~(pd.isna(df['Proteotypic ?']))] + df_filtered = df_filtered[df_filtered['Coverage']>=coverage_treshold] + df_filtered = df_filtered[pd.isna(df_filtered['Miscleavage ? '])] + peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts') + filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"] + df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())] + + df1_grouped = df_filtered.groupby("Protein.Names") + dico_final={} + # iterate over each group + for group_name, df_group in df1_grouped: + seq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['Stripped.Sequence'].to_list() + value_maxlfq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list() + value_maxlfq_frag = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list() + threshold_weak_flyer_maxflq = value_maxlfq[int(len(seq) / 3)] + threshold_medium_flyer_maxlfq = value_maxlfq[int(2 * len(seq) / 3)] + prot = df_group['Protein.Group'].to_list()[0] + + for i in range(len(seq)): + + if value_maxlfq_frag[i] < threshold_weak_flyer_maxflq : + label_maxlfq = 1 + elif value_maxlfq_frag[i] < threshold_medium_flyer_maxlfq : + label_maxlfq = 2 + else : + label_maxlfq = 3 + + dico_final[seq[i]] = (prot,label_maxlfq) + + df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Classes MaxLFQ']) + df_final['Sequences']=df_final.index + df_final = df_final.reset_index() + df_final=df_final[['Sequences','Proteins', 'Classes MaxLFQ']] + df_final.to_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv', index=False) + df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv', index=False) + + if __name__ == '__main__': - build_dataset( coverage_treshold=20, min_peptide=4, f_name='ISA_data/df_finetune_no_miscleavage.csv') + build_dataset_astral(coverage_treshold=20, min_peptide=15) diff --git a/main_fine_tune.py b/main_fine_tune.py index e295ebe..e145f6d 100644 --- a/main_fine_tune.py +++ b/main_fine_tune.py @@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report -def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=2): - df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv') +def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): + df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv') df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv') df_no_flyer['Classes'] = df_no_flyer[classe_type] df_no_flyer = df_no_flyer[['Sequences', 'Classes']] @@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0 list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :]) list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :]) - df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) - df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) + df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle + df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle df_train['Proteins']=0 df_val['Proteins'] = 0 - df_train.to_csv('temp_fine_tune_df_train.csv', index=False) - df_val.to_csv('temp_fine_tune_df_val.csv',index=False) + df_train.to_csv('df_preprocessed/df_train_ISA.csv', index=False) + df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False) + +def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): + df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv') + df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv') + df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ'] + df_no_flyer = df_no_flyer[['Sequences', 'Classes']] + df_flyer['Classes']=df_flyer['Classes MaxLFQ'] + df_flyer = df_flyer[['Sequences','Classes']] + #stratified split + list_train_split=[] + list_val_split =[] + for cl in [1,2,3]: + df_class = df_flyer[df_flyer['Classes']==cl] + class_count = df_class.shape[0] + list_train_split.append(df_class.iloc[:int(class_count*split[0]),:]) + list_val_split.append(df_class.iloc[int(class_count * split[0]):, :]) + + list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :]) + list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :]) + + df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle + df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle + + df_train['Proteins']=0 + df_val['Proteins'] = 0 + df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False) + df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False) + +def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): + df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv') + df_no_flyer_astral = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv') + df_no_flyer_astral['Classes'] = df_no_flyer_astral['Classes MaxLFQ'] + df_no_flyer_astral = df_no_flyer_astral[['Sequences', 'Classes']] + df_flyer_astral['Classes']=df_flyer_astral['Classes MaxLFQ'] + df_flyer_astral = df_flyer_astral[['Sequences','Classes']] + + df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv') + df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv') + df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ'] + df_no_flyer = df_no_flyer[['Sequences', 'Classes']] + df_flyer['Classes'] = df_flyer['Classes MaxLFQ'] + df_flyer = df_flyer[['Sequences', 'Classes']] + + + #stratified split + list_train_split=[] + list_val_split =[] + for cl in [1,2,3]: + df_class_astral = df_flyer_astral[df_flyer_astral['Classes']==cl] + class_count_astral = df_class_astral.shape[0] + df_class_ISA = df_flyer[df_flyer['Classes'] == cl] + class_count_ISA = df_class_ISA.shape[0] + + list_train_split.append(df_class_astral.iloc[:int(class_count_astral*split[0]),:]) + list_val_split.append(df_class_astral.iloc[int(class_count_astral * split[0]):, :]) + list_train_split.append(df_class_ISA.iloc[:int(class_count_ISA * split[0]), :]) + list_val_split.append(df_class_ISA.iloc[int(class_count_ISA * split[0]):, :]) + + list_train_split.append(df_no_flyer_astral.iloc[:int(class_count_astral * split[0] * frac_no_fly_train), :]) + list_val_split.append(df_no_flyer_astral.iloc[df_no_flyer_astral.shape[0]-int(class_count_astral * split[1] * frac_no_fly_val):, :]) + list_train_split.append(df_no_flyer.iloc[:int(class_count_ISA * split[0] * frac_no_fly_train), :]) + list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0] - int(class_count_ISA * split[1] * frac_no_fly_val):, :]) + + df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle + df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle + + df_train['Proteins']=0 + df_val['Proteins'] = 0 + df_train.to_csv('df_preprocessed/df_train_combined_7.csv', index=False) + df_val.to_csv('df_preprocessed/df_val_combined_multiclass_7.csv',index=False) def density_plot(prediction_path,prediction_path_2,criteria='base'): df = pd.read_csv(prediction_path) @@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'): def main(): total_num_classes = len(CLASSES_LABELS) - input_dimension = len(alphabet) num_cells = 64 - fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes) load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes) @@ -94,8 +162,8 @@ def main(): print('Initialising dataset') ## Data init - fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv', - val_data_source='temp_fine_tune_df_val.csv', + fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv', + val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv', data_format='csv', max_seq_len=max_pep_length, label_column="Classes", @@ -114,7 +182,7 @@ def main(): verbose=1, patience=5) - model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability' + model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT, monitor='val_loss', @@ -131,13 +199,13 @@ def main(): history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data, validation_data=fine_tune_data.tensor_val_data, - epochs=1, + epochs=150, callbacks=[callback_FT, model_checkpoint_FT]) ## Loading best model weights - # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability' - model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model + model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data + # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model fine_tuned_model.load_weights(model_save_path_FT) @@ -170,16 +238,17 @@ def main(): report_FT = DetectabilityReport(test_targets_FT_one_hot, predictions_FT, test_data_df_FT, - output_path='./output/report_on_ISA (Base model categorical train, binary val )', + output_path='./output/report_on_combined_15 (Fine tuned model (combined_10) categorical train, categorical val)', history=history_fine_tuned, rank_by_prot=True, threshold=None, - name_of_dataset='ISA val dataset (binary balanced)', - name_of_model='Base model (ISA)') + name_of_dataset='combined_15 val dataset (categorical balanced)', + name_of_model='Fine tuned model (combined_15)') report_FT.generate_report() if __name__ == '__main__': - create_ISA_dataset() + # create_astral_dataset() + # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1) main() # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv') -- GitLab