import pandas as pd
from datasets import load_dataset, DatasetDict


# Hugging Face detectability datasets to compare against the in-house ISA tables.
df_list = ["Wilhelmlab/detectability-proteometools",
           "Wilhelmlab/detectability-wang",
           "Wilhelmlab/detectability-sinitcyn"]
df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')

# Flyer and non-flyer peptides together; built ONCE instead of once per label type
# (the concatenation does not depend on the label column being inspected).
df_full = pd.concat([df_flyer, df_no_flyer])
df_size = df_full.shape[0]

# Class distribution of the ISA data for each labelling scheme
# (0 = no flyer, 1 = weak, 2 = intermediate, 3 = strong flyer).
for label_type in ['Classes fragment', 'Classes precursor', 'Classes MaxLFQ']:
    counts = df_full[label_type].value_counts()
    print('df ISA {} class repartition : No flyer {:.2f}%, Weak flyer {:.2f}%, '
          'Intermediate flyer {:.2f}%, Strong flyer {:.2f}%'.format(
              label_type,
              100 * counts.get(0, 0) / df_size,
              100 * counts.get(1, 0) / df_size,
              100 * counts.get(2, 0) / df_size,
              100 * counts.get(3, 0) / df_size))

l_inter_ISA = []
l_df_hg = []
for hf_data_name in df_list:
    # Concatenate every split of the HF dataset into a single frame.
    hf_dataset_split = load_dataset(hf_data_name)
    df_hg = pd.concat([pd.DataFrame(hf_dataset_split[k]) for k in hf_dataset_split.keys()])

    df_hg_size = df_hg.shape[0]
    counts_hg = df_hg['Classes'].value_counts()
    print('df {} class repartition : No flyer {:.2f}%, Weak flyer {:.2f}%, '
          'Intermediate flyer {:.2f}%, Strong flyer {:.2f}%'.format(
              hf_data_name,
              100 * counts_hg.get(0, 0) / df_hg_size,
              100 * counts_hg.get(1, 0) / df_hg_size,
              100 * counts_hg.get(2, 0) / df_hg_size,
              100 * counts_hg.get(3, 0) / df_hg_size))

    # Overlap with the ISA data, matched on peptide sequence.
    # NOTE(review): labels are compared against 'Classes MaxLFQ' only — confirm
    # the other two ISA label columns are intentionally ignored here.
    df_common = df_hg.join(df_full.set_index('Sequences'), on='Sequences',
                           how='inner', lsuffix='_hg', rsuffix='_ISA')
    size_inter = df_common.shape[0]
    same_label = df_common[df_common['Classes'] == df_common['Classes MaxLFQ']].shape[0]
    l_inter_ISA.append(df_common)
    if size_inter:  # guard: an empty intersection would raise ZeroDivisionError
        print('Inter with ISA df size : {}, similar label : {:.2f}%'.format(
            size_inter, 100 * same_label / size_inter))
    else:
        print('Inter with ISA df size : 0')

    # Pairwise label agreement with the HF datasets processed so far.
    for df_hg_bis in l_df_hg:
        df_common = df_hg.join(df_hg_bis.set_index('Sequences'), on='Sequences',
                               how='inner', lsuffix='_hg', rsuffix='_hg_bis')
        size_inter = df_common.shape[0]
        same_label_size = df_common[df_common['Classes_hg'] == df_common['Classes_hg_bis']].shape[0]
        cf_matrix = pd.crosstab(df_common['Classes_hg'], df_common['Classes_hg_bis'])
        if size_inter:  # same empty-intersection guard as above
            print('Inter with df hg bis df size : {}, similar label : {:.2f}%'.format(
                size_inter, 100 * same_label_size / size_inter))
        else:
            print('Inter with df hg bis df size : 0')
        print(cf_matrix)
    l_df_hg.append(df_hg)
import pandas as pd


def _tertile_label(value, weak_threshold, medium_threshold):
    """Map an intensity value to a flyer class: 1 = weak, 2 = intermediate, 3 = strong."""
    if value < weak_threshold:
        return 1
    if value < medium_threshold:
        return 2
    return 3


def build_dataset(intensity_col='Fragment.Quant.Raw', coverage_treshold=20, min_peptide=4, f_name='out_df.csv'):
    """Build the flyer and non-flyer peptide detectability training tables.

    Writes two CSVs: the flyer table (``f_name``) with per-peptide classes
    derived from fragment, precursor and MaxLFQ intensities, and the
    non-flyer table (``ISA_data/df_non_flyer_no_miscleavage.csv``).

    :param intensity_col: column used as the primary (fragment) intensity metric.
    :param coverage_treshold: minimum protein coverage (%) for flyer peptides.
    :param min_peptide: minimum number of peptides a protein must have to be kept.
    :param f_name: output path for the flyer table.
    """
    df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx')
    df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv')

    # --- Non-flyers: never quantified (MaxLFQ == 0), no cysteine, no missed cleavage.
    # NOTE: 'Cystein ? ' and (below) 'Coverage ' carry trailing spaces in the
    # source spreadsheets — do not "fix" them.
    df_non_flyer = df_non_flyer[df_non_flyer['Cystein ? '] == 'Any']
    df_non_flyer = df_non_flyer[df_non_flyer['Miscleavage ?'] == 'Any']
    df_non_flyer = df_non_flyer[df_non_flyer['MaxLFQ'] == 0.0]
    df_non_flyer = df_non_flyer.rename(columns={'Peptide': 'Sequences', 'ProteinID': 'Proteins'})
    df_non_flyer = df_non_flyer[['Sequences', 'Proteins']].drop_duplicates()
    # All three labelling schemes agree on class 0 for non-flyers.
    for class_col in ('Classes fragment', 'Classes precursor', 'Classes MaxLFQ'):
        df_non_flyer[class_col] = 0

    # --- Flyers: proteotypic peptides of well-covered proteins with enough peptides.
    df_filtered = df[df['Proteotypic ?'] == 'Proteotypic']
    df_filtered = df_filtered[df_filtered['Coverage '] >= coverage_treshold]
    df_filtered = df_filtered[df_filtered['Miscleavage ?'] == 'Any']
    peptide_count = df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
    filtered_sequence = peptide_count[peptide_count['counts'] >= min_peptide]["Protein.Names"]
    df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]

    dico_final = {}
    # For each protein, label every peptide by which tertile its intensity falls
    # in, for each of the three metrics.
    for group_name, df_group in df_filtered.groupby("Protein.Names"):
        by_frag = df_group.sort_values(by=[intensity_col])
        seq = by_frag['Stripped.Sequence'].to_list()
        value_frag = by_frag[intensity_col].to_list()
        value_prec_frag = by_frag['Precursor.Quantity'].to_list()
        value_maxlfq_frag = by_frag['MaxLFQ'].to_list()
        # Tertile thresholds are taken from each metric's OWN sorted values,
        # while per-peptide values keep the fragment-intensity ordering.
        value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list()
        value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list()
        t1 = int(len(seq) / 3)
        t2 = int(2 * len(seq) / 3)
        prot = df_group['Protein.Group'].iloc[0]

        for i, peptide in enumerate(seq):
            # NOTE(review): a peptide shared by several proteins keeps only the
            # labels of the last protein processed — confirm this is intended.
            dico_final[peptide] = (
                prot,
                _tertile_label(value_frag[i], value_frag[t1], value_frag[t2]),
                _tertile_label(value_prec_frag[i], value_prec[t1], value_prec[t2]),
                _tertile_label(value_maxlfq_frag[i], value_maxlfq[t1], value_maxlfq[t2]),
            )

    df_final = pd.DataFrame.from_dict(
        dico_final, orient='index',
        columns=['Proteins', 'Classes fragment', 'Classes precursor', 'Classes MaxLFQ'])
    # Promote the peptide index to a proper 'Sequences' column.
    df_final = df_final.reset_index().rename(columns={'index': 'Sequences'})
    df_final = df_final[['Sequences', 'Proteins', 'Classes fragment', 'Classes precursor', 'Classes MaxLFQ']]
    df_final.to_csv(f_name, index=False)
    df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False)


if __name__ == '__main__':
    build_dataset(coverage_treshold=20, min_peptide=4, f_name='ISA_data/df_finetune_no_miscleavage.csv')
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from dlomix.models import DetectabilityModel
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
from dlomix.data import DetectabilityDataset
from datasets import load_dataset, DatasetDict
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report


def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed=42, split=(0.8, 0.2),
                       frac_no_fly_train=1, frac_no_fly_val=2):
    """Write stratified train/val CSVs for fine-tuning from the ISA tables.

    :param classe_type: which ISA label column becomes the 'Classes' target.
    :param manual_seed: seed for the shuffle of both splits.
    :param split: (train_fraction, val_fraction) applied per flyer class.
    :param frac_no_fly_train: non-flyer count in train, as a multiple of the reference class size.
    :param frac_no_fly_val: non-flyer count in val, as a multiple of the reference class size.
    """
    df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split over the three flyer classes (1 = weak .. 3 = strong).
    list_train_split = []
    list_val_split = []
    class_count = 0
    for cl in [1, 2, 3]:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = df_class.shape[0]
        n_train = int(class_count * split[0])
        list_train_split.append(df_class.iloc[:n_train, :])
        list_val_split.append(df_class.iloc[n_train:, :])

    # NOTE(review): non-flyer sampling is scaled by the size of the LAST flyer
    # class (cl == 3), matching the original behaviour — confirm this is meant
    # rather than the total flyer count.
    n_no_fly_train = int(class_count * split[0] * frac_no_fly_train)
    n_no_fly_val = int(class_count * split[1] * frac_no_fly_val)
    list_train_split.append(df_no_flyer.iloc[:n_no_fly_train, :])
    # Validation non-flyers are drawn from the END of the table so they do not
    # overlap the training non-flyers taken from the front.
    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0] - n_no_fly_val:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)

    # DetectabilityDataset expects a 'Proteins' column; a dummy value suffices here.
    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    df_train.to_csv('temp_fine_tune_df_train.csv', index=False)
    df_val.to_csv('temp_fine_tune_df_val.csv', index=False)


def density_plot(prediction_path, prediction_path_2, criteria='base'):
    """Overlay flyer/non-flyer score densities from two prediction reports.

    Produces 'density_sum.png' (raw 'Flyer' score) and 'density_base.png'
    (max flyer-class probability minus non-flyer probability).
    """
    df = pd.read_csv(prediction_path)
    df['actual_metrics'] = -df['Non-Flyer'] + df[["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']].max(axis=1)
    flyer = df[df['Binary Classes'] == 'Flyer']
    non_flyer = df[df['Binary Classes'] == 'Non-Flyer']

    df2 = pd.read_csv(prediction_path_2)
    df2['actual_metrics'] = -df2['Non-Flyer'] + df2[["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']].max(axis=1)
    flyer2 = df2[df2['Binary Classes'] == 'Flyer']
    non_flyer2 = df2[df2['Binary Classes'] == 'Non-Flyer']

    non_flyer['Flyer'].plot.density(bw_method=0.2, color='blue', linestyle='-', linewidth=2)
    flyer['Flyer'].plot.density(bw_method=0.2, color='red', linestyle='-', linewidth=2)
    non_flyer2['Flyer'].plot.density(bw_method=0.2, color='yellow', linestyle='-', linewidth=2)
    flyer2['Flyer'].plot.density(bw_method=0.2, color='green', linestyle='-', linewidth=2)
    plt.legend(['non-flyer base', 'flyer base', 'non-flyer fine tuned', 'flyer fine tuned'])
    plt.xlabel('Flyer index')
    plt.xlim(0, 1)
    # BUG FIX: savefig must run BEFORE show() — show() consumes the current
    # figure, so saving afterwards wrote a blank image.
    plt.savefig('density_sum.png')
    plt.show()
    plt.clf()

    non_flyer['actual_metrics'].plot.density(bw_method=0.2, color='blue', linestyle='-', linewidth=2)
    flyer['actual_metrics'].plot.density(bw_method=0.2, color='red', linestyle='-', linewidth=2)
    non_flyer2['actual_metrics'].plot.density(bw_method=0.2, color='yellow', linestyle='-', linewidth=2)
    flyer2['actual_metrics'].plot.density(bw_method=0.2, color='green', linestyle='-', linewidth=2)
    plt.legend(['non-flyer base', 'flyer base', 'non-flyer fine tuned', 'flyer fine tuned'])
    plt.xlabel('Flyer index')
    plt.xlim(-1, 1)
    plt.savefig('density_base.png')  # same ordering fix as above
    plt.show()
def main():
    """Fine-tune the detectability model on the ISA split (1 epoch) and report.

    Loads the published fine-tuned weights, trains briefly on the CSVs written
    by ``create_ISA_dataset``, then reloads the BASE weights and generates a
    DetectabilityReport on the validation split.
    """
    total_num_classes = len(CLASSES_LABELS)
    input_dimension = len(alphabet)  # kept for reference; model derives its own input size
    num_cells = 64

    # Single model instance (the original constructed it twice and discarded one).
    fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes)
    load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
    fine_tuned_model.load_weights(load_model_path)

    max_pep_length = 40
    # Batch size has no impact for prediction.
    batch_size = 16

    print('Initialising dataset')
    fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv',
                                          val_data_source='temp_fine_tune_df_val.csv',
                                          data_format='csv',
                                          max_seq_len=max_pep_length,
                                          label_column="Classes",
                                          sequence_column="Sequences",
                                          dataset_columns_to_keep=["Proteins"],
                                          batch_size=batch_size,
                                          with_termini=False,
                                          alphabet=aa_to_int_dict)

    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   mode='min',
                                                   verbose=1,
                                                   patience=5)

    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
                                                             monitor='val_loss',
                                                             mode='min',
                                                             verbose=1,
                                                             save_best_only=True,
                                                             save_weights_only=True)
    opti = tf.keras.optimizers.legacy.Adagrad()
    fine_tuned_model.compile(optimizer=opti,
                             loss='SparseCategoricalCrossentropy',
                             metrics='sparse_categorical_accuracy')

    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                              validation_data=fine_tune_data.tensor_val_data,
                                              epochs=1,
                                              callbacks=[callback_FT, model_checkpoint_FT])

    # Deliberately evaluate the BASE model, not the fresh checkpoint; swap the
    # commented path back in to evaluate the fine-tuned weights instead.
    # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
    model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'  # base model
    fine_tuned_model.load_weights(model_save_path_FT)

    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)

    # Validation targets, plus their decoded class names for debugging.
    test_targets_FT = fine_tune_data["val"]["Classes"]
    test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]

    # The dataframe the report expects: raw sequences, targets, protein ids.
    test_data_df_FT = pd.DataFrame(
        {
            "Sequences": fine_tune_data["val"]["_parsed_sequence"],
            "Classes": test_targets_FT,
            "Proteins": fine_tune_data["val"]["Proteins"],
        }
    )
    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))

    # The report expects one-hot encoded true labels.
    num_classes = np.max(test_targets_FT) + 1
    test_targets_FT_one_hot = np.eye(num_classes)[test_targets_FT]

    report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                    predictions_FT,
                                    test_data_df_FT,
                                    output_path='./output/report_on_ISA (Base model categorical train, binary val )',
                                    history=history_fine_tuned,
                                    rank_by_prot=True,
                                    threshold=None,
                                    name_of_dataset='ISA val dataset (binary balanced)',
                                    name_of_model='Base model (ISA)')

    report_FT.generate_report()


if __name__ == '__main__':
    create_ISA_dataset()
    main()
    # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv',
    #              'output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')