From 0a226689532fe8ea9e6a1b680c15d8b7db80eace Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 14 May 2025 16:07:13 +0200
Subject: [PATCH] binary pretraining

---
 binary_training.py | 235 +++++++++++++++++++++++++++++++++++++++++++++
 dummy.csv          |   8 +-
 main_fine_tune.py  |  43 +++++----
 3 files changed, 265 insertions(+), 21 deletions(-)
 create mode 100644 binary_training.py

diff --git a/binary_training.py b/binary_training.py
new file mode 100644
index 0000000..f9a4a64
--- /dev/null
+++ b/binary_training.py
@@ -0,0 +1,235 @@
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from dlomix.models import DetectabilityModel
+from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
+from dlomix.data import DetectabilityDataset
+from os.path import join, exists
+from os import makedirs
+from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve
+
+def create_ISA_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
+    df_no_flyer['Classes'] = 0
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes'] = 1
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_ISA_binary.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_ISA_binary.csv',index=False)
+
+def create_astral_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
+    df_no_flyer['Classes'] = 0
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes'] = 1
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_astral_binary.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_binary.csv',index=False)
+
+
+def main(epoch):
+    total_num_classes = 2
+    num_cells = 64
+
+    load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
+    fine_tuned_model = DetectabilityModel(num_units=num_cells,
+                                          num_clases=2)
+    fine_tuned_model.build((None, 40))
+
+    base_arch  = DetectabilityModel(num_units=num_cells,
+                                          num_clases=4)
+    base_arch.load_weights(load_model_path)
+
+    #partially loading pretrained weights (multiclass training)
+    base_arch.build((None, 40))
+    weights_list = base_arch.get_weights()
+    weights_list[-1] = np.array([0.,0.],dtype=np.float32)
+    weights_list[-2] = np.zeros((128,2),dtype=np.float32)
+    fine_tuned_model.set_weights(weights_list)
+
+
+
+    max_pep_length = 40
+    ## Has no impact for prediction
+    batch_size = 16
+
+    print('Initialising dataset')
+    ## Data init
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_binary.csv',
+                                              val_data_source='df_preprocessed/df_val_astral_binary.csv',
+                                              data_format='csv',
+                                              max_seq_len=max_pep_length,
+                                              label_column="Classes",
+                                              sequence_column="Sequences",
+                                              dataset_columns_to_keep=["Proteins"],
+                                              batch_size=batch_size,
+                                              with_termini=False,
+                                              alphabet=aa_to_int_dict)
+
+
+
+
+    # compile the model  with the optimizer and the metrics we want to use.
+    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
+                                                   mode='min',
+                                                   verbose=1,
+                                                   patience=5)
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
+
+    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
+                                                             monitor='val_loss',
+                                                             mode='min',
+                                                             verbose=1,
+                                                             save_best_only=True,
+                                                             save_weights_only=True)
+    opti = tf.keras.optimizers.legacy.Adagrad()
+    fine_tuned_model.compile(optimizer=opti,
+                             loss='SparseCategoricalCrossentropy',
+                             metrics='sparse_categorical_accuracy')
+
+
+
+    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
+                                              validation_data=fine_tune_data.tensor_val_data,
+                                              epochs=epoch,
+                                              callbacks=[callback_FT, model_checkpoint_FT])
+    ## Loading best model weights
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
+    # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
+
+    fine_tuned_model.load_weights(model_save_path_FT)
+
+    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)
+
+    # access val dataset and get the Classes column
+    test_targets_FT = fine_tune_data["val"]["Classes"]
+
+    # The dataframe needed for the report
+
+    test_data_df_FT = pd.DataFrame(
+        {
+            "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # get the raw parsed sequences
+            "Classes": test_targets_FT,  # get the test targets from above
+            "Proteins": fine_tune_data["val"]["Proteins"],  # get the Proteins column from the dataset object
+            "Prob non flyer": predictions_FT[:,0],
+            "Prob flyer": predictions_FT[:, 1],
+            "Predicted classes" : np.argmax(predictions_FT,axis=1)
+        }
+    )
+
+
+
+    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))
+
+    return test_data_df_FT, history_fine_tuned
+
+def plot_and_save_metrics(history, base_path):
+    history_dict = history.history
+    metrics = history_dict.keys()
+    metrics = filter(lambda x: not x.startswith(tuple(["val_", "_"])), metrics)
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    for metric_name in metrics:
+        plt.plot(history_dict[metric_name])
+        plt.plot(history_dict["val_" + metric_name])
+        plt.title(metric_name, fontsize=10)  # Modified Original plt.title(metric_name)
+        plt.ylabel(metric_name)
+        plt.xlabel("epoch")
+        plt.legend(["train", "val"], loc="best")
+        save_path = join(base_path, metric_name)
+        plt.savefig(
+            save_path, bbox_inches="tight", dpi=90
+        )  # Modification Original plt.savefig(save_path)
+        plt.close()
+
+def plot_confusion_matrix(df, base_path):
+    conf_matrix = confusion_matrix(
+        df["Classes"],
+        df["Predicted classes"],
+    )
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    save_path = join(base_path, "confusion_matrix_binary"
+    )
+    plt.savefig(save_path, bbox_inches="tight", dpi=80)
+    plt.close()
+
+def plot_roc(df,base_path):
+    fpr, tpr, thresholds = roc_curve(
+        np.array(df["Classes"]),
+        np.array(df["Prob flyer"]),
+    )
+    AUC_score = auc(fpr, tpr)
+
+    # create ROC curve
+
+    plt.plot(fpr, tpr, label="ROC curve of (area = {})".format(AUC_score))
+    plt.title(
+        "Receiver operating characteristic curve (Binary classification)",
+        y=1.04,
+        fontsize=10,
+    )
+
+    plt.ylabel("True Positive Rate")
+    plt.xlabel("False Positive Rate")
+    save_path = join(
+        base_path, "ROC_curve_binary_classification"
+    )
+
+    plt.savefig(save_path, bbox_inches="tight", dpi=90)
+    plt.close()
+
+if __name__ == '__main__':
+    create_astral_binary_dataset()
+    test_data_df_FT, history = main(epoch=150)
+    base_path = 'output/binary_astral'
+    plot_and_save_metrics(history,base_path)
+    plot_confusion_matrix(test_data_df_FT,base_path)
+    plot_roc(test_data_df_FT,base_path)
\ No newline at end of file
diff --git a/dummy.csv b/dummy.csv
index 63e4360..7f0c4e5 100644
--- a/dummy.csv
+++ b/dummy.csv
@@ -1,2 +1,8 @@
 Sequences,Classes,Proteins
-IVDDLSALTVLEASELSK,0,0
\ No newline at end of file
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,1,0
+IVDDLSALTVLEASELSK,1,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
+IVDDLSALTVLEASELSK,0,0
diff --git a/main_fine_tune.py b/main_fine_tune.py
index e145f6d..d4ad775 100644
--- a/main_fine_tune.py
+++ b/main_fine_tune.py
@@ -19,14 +19,16 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
     #stratified split
     list_train_split=[]
     list_val_split =[]
+    total_count = 0
     for cl in [1,2,3]:
         df_class = df_flyer[df_flyer['Classes']==cl]
         class_count = df_class.shape[0]
         list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
         list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
-
-    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
-    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+        total_count+=class_count
+    total_count=total_count/3
+    list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
 
     df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
     df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
@@ -37,7 +39,7 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
     df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False)
 
 def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
-    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv')
+    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
     df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
     df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
     df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
@@ -46,22 +48,25 @@ def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1
     #stratified split
     list_train_split=[]
     list_val_split =[]
+    total_count = 0
     for cl in [1,2,3]:
         df_class = df_flyer[df_flyer['Classes']==cl]
         class_count = df_class.shape[0]
         list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
         list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
+        total_count += class_count
+    total_count = total_count / 3
 
-    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
-    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+    list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
 
     df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
     df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
 
     df_train['Proteins']=0
     df_val['Proteins'] = 0
-    df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False)
-    df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False)
+    df_train.to_csv('df_preprocessed/df_train_astral_4.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_binary_4.csv',index=False)
 
 def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
     df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
@@ -152,7 +157,7 @@ def main():
     fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                           num_clases=total_num_classes)
 
-    fine_tuned_model.load_weights(load_model_path)
+    # fine_tuned_model.load_weights(load_model_path)
 
 
 
@@ -162,8 +167,8 @@ def main():
 
     print('Initialising dataset')
     ## Data init
-    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv',
-                                              val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv',
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_4.csv',
+                                              val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
                                               data_format='csv',
                                               max_seq_len=max_pep_length,
                                               label_column="Classes",
@@ -199,12 +204,12 @@ def main():
 
     history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                               validation_data=fine_tune_data.tensor_val_data,
-                                              epochs=150,
+                                              epochs=450,
                                               callbacks=[callback_FT, model_checkpoint_FT])
 
     ## Loading best model weights
 
-    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
+    # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
     # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
 
     fine_tuned_model.load_weights(model_save_path_FT)
@@ -214,9 +219,6 @@ def main():
     # access val dataset and get the Classes column
     test_targets_FT = fine_tune_data["val"]["Classes"]
 
-    # if needed, the decoded version of the classes can be retrieved by looking up the class names
-    test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]
-
     # The dataframe needed for the report
 
     test_data_df_FT = pd.DataFrame(
@@ -238,17 +240,18 @@ def main():
     report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                     predictions_FT,
                                     test_data_df_FT,
-                                    output_path='./output/report_on_combined_15 (Fine tuned model (combined_10) categorical train, categorical val)',
+                                    output_path='./output/report_on_astral_4 (from scratch model categorical train, categorical val )',
                                     history=history_fine_tuned,
                                     rank_by_prot=True,
                                     threshold=None,
-                                    name_of_dataset='combined_15 val dataset (categorical balanced)',
-                                    name_of_model='Fine tuned model (combined_15)')
+                                    name_of_dataset='astral_4 val dataset (Categorical balanced)',
+                                    name_of_model='From scratch model')
 
     report_FT.generate_report()
 
 if __name__ == '__main__':
-    # create_astral_dataset()
+    create_astral_dataset(frac_no_fly_val=1)
     # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
+    create_ISA_dataset(frac_no_fly_val=1)
     main()
     # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
-- 
GitLab