From 0a226689532fe8ea9e6a1b680c15d8b7db80eace Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 14 May 2025 16:07:13 +0200
Subject: [PATCH] binary pretraining

---
 binary_training.py | 235 +++++++++++++++++++++++++++++++++++++++++++++
 dummy.csv          |   8 +-
 main_fine_tune.py  |  43 +++++----
 3 files changed, 265 insertions(+), 21 deletions(-)
 create mode 100644 binary_training.py

diff --git a/binary_training.py b/binary_training.py
new file mode 100644
index 0000000..f9a4a64
--- /dev/null
+++ b/binary_training.py
@@ -0,0 +1,235 @@
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from dlomix.models import DetectabilityModel
+from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
+from dlomix.data import DetectabilityDataset
+from os.path import join, exists
+from os import makedirs
+from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve
+
+def create_ISA_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
+    df_no_flyer['Classes'] = 0
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes'] = 1
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_ISA_binary.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_ISA_binary.csv',index=False)
+
+
+def create_astral_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
+    df_no_flyer['Classes'] = 0
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes'] = 1
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_astral_binary.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_binary.csv',index=False)
+
+
+def main(epoch):
+    total_num_classes = 2
+    num_cells = 64
+
+    load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
+    fine_tuned_model = DetectabilityModel(num_units=num_cells,
+                                          num_clases=2)
+    fine_tuned_model.build((None, 40))
+
+    base_arch = DetectabilityModel(num_units=num_cells,
+                                   num_clases=4)
+    base_arch.load_weights(load_model_path)
+
+    #partially loading pretrained weights (multiclass training)
+    base_arch.build((None, 40))
+    weights_list = base_arch.get_weights()
+    weights_list[-1] = np.array([0.,0.],dtype=np.float32)
+    weights_list[-2] = np.zeros((128,2),dtype=np.float32)
+    fine_tuned_model.set_weights(weights_list)
+
+
+
+    max_pep_length = 40
+    ## Has no impact for prediction
+    batch_size = 16
+
+    print('Initialising dataset')
+    ## Data init
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_binary.csv',
+                                          val_data_source='df_preprocessed/df_val_astral_binary.csv',
+                                          data_format='csv',
+                                          max_seq_len=max_pep_length,
+                                          label_column="Classes",
+                                          sequence_column="Sequences",
+                                          dataset_columns_to_keep=["Proteins"],
+                                          batch_size=batch_size,
+                                          with_termini=False,
+                                          alphabet=aa_to_int_dict)
+
+
+
+
+    # compile the model with the optimizer and the metrics we want to use.
+    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
+                                                   mode='min',
+                                                   verbose=1,
+                                                   patience=5)
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
+
+    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
+                                                             monitor='val_loss',
+                                                             mode='min',
+                                                             verbose=1,
+                                                             save_best_only=True,
+                                                             save_weights_only=True)
+    opti = tf.keras.optimizers.legacy.Adagrad()
+    fine_tuned_model.compile(optimizer=opti,
+                             loss='SparseCategoricalCrossentropy',
+                             metrics='sparse_categorical_accuracy')
+
+
+
+    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
+                                              validation_data=fine_tune_data.tensor_val_data,
+                                              epochs=epoch,
+                                              callbacks=[callback_FT, model_checkpoint_FT])
+    ## Loading best model weights
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
+    # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
+
+    fine_tuned_model.load_weights(model_save_path_FT)
+
+    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)
+
+    # access val dataset and get the Classes column
+    test_targets_FT = fine_tune_data["val"]["Classes"]
+
+    # The dataframe needed for the report
+
+    test_data_df_FT = pd.DataFrame(
+        {
+            "Sequences": fine_tune_data["val"]["_parsed_sequence"], # get the raw parsed sequences
+            "Classes": test_targets_FT, # get the test targets from above
+            "Proteins": fine_tune_data["val"]["Proteins"], # get the Proteins column from the dataset object
+            "Prob non flyer": predictions_FT[:,0],
+            "Prob flyer": predictions_FT[:, 1],
+            "Predicted classes" : np.argmax(predictions_FT,axis=1)
+        }
+    )
+
+
+
+    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))
+
+    return test_data_df_FT, history_fine_tuned
+
+def plot_and_save_metrics(history, base_path):
+    history_dict = history.history
+    metrics = history_dict.keys()
+    metrics = filter(lambda x: not x.startswith(tuple(["val_", "_"])), metrics)
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    for metric_name in metrics:
+        plt.plot(history_dict[metric_name])
+        plt.plot(history_dict["val_" + metric_name])
+        plt.title(metric_name, fontsize=10)  # Modified Original plt.title(metric_name)
+        plt.ylabel(metric_name)
+        plt.xlabel("epoch")
+        plt.legend(["train", "val"], loc="best")
+        save_path = join(base_path, metric_name)
+        plt.savefig(
+            save_path, bbox_inches="tight", dpi=90
+        )  # Modification Original plt.savefig(save_path)
+        plt.close()
+
+def plot_confusion_matrix(df, base_path):
+    conf_matrix = confusion_matrix(
+        df["Classes"],
+        df["Predicted classes"],
+    )
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    save_path = join(base_path, "confusion_matrix_binary"
+    )
+    plt.savefig(save_path, bbox_inches="tight", dpi=80)
+    plt.close()
+
+def plot_roc(df,base_path):
+    fpr, tpr, thresholds = roc_curve(
+        np.array(df["Classes"]),
+        np.array(df["Prob flyer"]),
+    )
+    AUC_score = auc(fpr, tpr)
+
+    # create ROC curve
+
+    plt.plot(fpr, tpr, label="ROC curve of (area = {})".format(AUC_score))
+    plt.title(
+        "Receiver operating characteristic curve (Binary classification)",
+        y=1.04,
+        fontsize=10,
+    )
+
+    plt.ylabel("True Positive Rate")
+    plt.xlabel("False Positive Rate")
+    save_path = join(
+        base_path, "ROC_curve_binary_classification"
+    )
+
+    plt.savefig(save_path, bbox_inches="tight", dpi=90)
+    plt.close()
+
+if __name__ == '__main__':
+    create_astral_binary_dataset()
+    test_data_df_FT, history = main(epoch=150)
+    base_path = 'output/binary_astral'
+    plot_and_save_metrics(history,base_path)
+    plot_confusion_matrix(test_data_df_FT,base_path)
+    plot_roc(test_data_df_FT,base_path)
\ No newline at end of file
diff --git a/dummy.csv b/dummy.csv
index 63e4360..7f0c4e5 100644
--- a/dummy.csv
+++ b/dummy.csv
@@ -1,2 +1,8 @@
 Sequences,Classes,Proteins
-IVDDLSALTVLEASELSK,0,0
\ No newline at end of file
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,1,0
+IVDDLSALTVLEASELSK,1,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
diff --git a/main_fine_tune.py b/main_fine_tune.py
index e145f6d..d4ad775 100644
--- a/main_fine_tune.py
+++ b/main_fine_tune.py
@@ -19,14 +19,16 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
     #stratified split
     list_train_split=[]
     list_val_split =[]
+    total_count = 0
     for cl in [1,2,3]:
         df_class = df_flyer[df_flyer['Classes']==cl]
         class_count = df_class.shape[0]
         list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
         list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
-
-    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
-    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+        total_count+=class_count
+    total_count=total_count/3
+    list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
 
     df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
     df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
@@ -37,7 +39,7 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
     df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False)
 
 def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
-    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv')
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
     df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
     df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
     df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
@@ -46,22 +48,25 @@ def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1
     #stratified split
     list_train_split=[]
     list_val_split =[]
+    total_count = 0
     for cl in [1,2,3]:
         df_class = df_flyer[df_flyer['Classes']==cl]
         class_count = df_class.shape[0]
         list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
         list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
+        total_count += class_count
+    total_count = total_count / 3
 
-    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
-    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+    list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
 
     df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
     df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
 
     df_train['Proteins']=0
     df_val['Proteins'] = 0
-    df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False)
-    df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False)
+    df_train.to_csv('df_preprocessed/df_train_astral_4.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_binary_4.csv',index=False)
 
 def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
     df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
@@ -152,7 +157,7 @@ def main():
 
     fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                           num_clases=total_num_classes)
-    fine_tuned_model.load_weights(load_model_path)
+    # fine_tuned_model.load_weights(load_model_path)
 
 
 
@@ -162,8 +167,8 @@ def main():
 
     print('Initialising dataset')
     ## Data init
-    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv',
-                                          val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv',
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_4.csv',
+                                          val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
                                           data_format='csv',
                                           max_seq_len=max_pep_length,
                                           label_column="Classes",
@@ -199,12 +204,12 @@ def main():
 
     history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                               validation_data=fine_tune_data.tensor_val_data,
-                                              epochs=150,
+                                              epochs=450,
                                               callbacks=[callback_FT, model_checkpoint_FT])
     ## Loading best model weights
 
 
-    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
+    # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
     # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
 
     fine_tuned_model.load_weights(model_save_path_FT)
@@ -214,9 +219,6 @@ def main():
 
     # access val dataset and get the Classes column
     test_targets_FT = fine_tune_data["val"]["Classes"]
-    # if needed, the decoded version of the classes can be retrieved by looking up the class names
-    test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]
-
     # The dataframe needed for the report
 
     test_data_df_FT = pd.DataFrame(
@@ -238,17 +240,18 @@ def main():
     report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                     predictions_FT,
                                     test_data_df_FT,
-                                    output_path='./output/report_on_combined_15 (Fine tuned model (combined_10) categorical train, categorical val)',
+                                    output_path='./output/report_on_astral_4 (from scratch model categorical train, categorical val )',
                                     history=history_fine_tuned,
                                     rank_by_prot=True,
                                     threshold=None,
-                                    name_of_dataset='combined_15 val dataset (categorical balanced)',
-                                    name_of_model='Fine tuned model (combined_15)')
+                                    name_of_dataset='astral_4 val dataset (Categorical balanced)',
+                                    name_of_model='From scratch model')
 
     report_FT.generate_report()
 
 if __name__ == '__main__':
-    # create_astral_dataset()
+    create_astral_dataset(frac_no_fly_val=1)
     # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
+    create_ISA_dataset(frac_no_fly_val=1)
     main()
     # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
-- 
GitLab