Skip to content
Snippets Groups Projects
main_fine_tune.py 12.49 KiB
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from dlomix.models import DetectabilityModel
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
from dlomix.data import DetectabilityDataset
from datasets import load_dataset, DatasetDict
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report


def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed=42, split=(0.8, 0.2), frac_no_fly_train=1, frac_no_fly_val=1):
    """Build stratified train/validation CSVs from the ISA flyer / non-flyer tables.

    Flyer peptides of classes 1, 2 and 3 are split per class:
    - the first ``int(count * split[0])`` rows of each class go to training;
    - the remaining rows of that class go to validation.
    No shuffling happens before the split, so row order in the input CSV
    decides which rows land where.

    Non-flyer peptides are then added relative to the mean flyer-class size:
    - the first ``int(mean * split[0] * frac_no_fly_train)`` rows go to training;
    - the last ``int(mean * split[1] * frac_no_fly_val)`` rows go to validation.
    If more rows are requested than the table holds, the whole table is used.
    In that case the training and validation non-flyers can overlap.

    Args:
        classe_type: Column of the input CSVs copied into the ``Classes`` label column.
        manual_seed: ``random_state`` used to shuffle the concatenated splits.
        split: ``(train_fraction, val_fraction)``.
        frac_no_fly_train: Multiplier on the number of non-flyer training rows.
        frac_no_fly_val: Multiplier on the number of non-flyer validation rows.

    Side effects:
        Reads ``ISA_data/df_flyer_no_miscleavage.csv`` and
        ``ISA_data/df_non_flyer_no_miscleavage.csv``.
        Writes ``df_preprocessed/df_train_ISA.csv`` and
        ``df_preprocessed/df_val_ISA_multiclass.csv``, creating the directory if needed.
        Both outputs have the columns ``Sequences``, ``Classes``, ``Proteins`` (all 0).
    """
    import os

    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split: the head of each flyer class trains, the tail validates.
    list_train_split = []
    list_val_split = []
    total_count = 0
    for cl in [1, 2, 3]:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = df_class.shape[0]
        n_train = int(class_count * split[0])
        list_train_split.append(df_class.iloc[:n_train, :])
        list_val_split.append(df_class.iloc[n_train:, :])
        total_count += class_count
    # Mean flyer-class size; the non-flyer counts are scaled from it.
    total_count = total_count / 3

    n_no_fly_train = int(total_count * split[0] * frac_no_fly_train)
    n_no_fly_val = int(total_count * split[1] * frac_no_fly_val)
    list_train_split.append(df_no_flyer.iloc[:n_no_fly_train, :])
    # Validation non-flyers come from the end of the table. Clamp at 0: a
    # negative start would make iloc count from the end and return fewer rows
    # than requested instead of the whole table.
    val_start = max(0, df_no_flyer.shape[0] - n_no_fly_val)
    list_val_split.append(df_no_flyer.iloc[val_start:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)  # shuffle

    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    os.makedirs('df_preprocessed', exist_ok=True)
    df_train.to_csv('df_preprocessed/df_train_ISA.csv', index=False)
    df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv', index=False)

def create_astral_dataset(manual_seed=42, split=(0.8, 0.2), frac_no_fly_train=1, frac_no_fly_val=1,
                          classe_type='Classes MaxLFQ'):
    """Build stratified train/validation CSVs from the Astral flyer / non-flyer tables.

    Flyer peptides of classes 1, 2 and 3 are split per class:
    - the first ``int(count * split[0])`` rows of each class go to training;
    - the remaining rows of that class go to validation.

    Non-flyer peptides are added relative to the mean flyer-class size:
    - the first ``int(mean * split[0] * frac_no_fly_train)`` rows go to training;
    - the last ``int(mean * split[1] * frac_no_fly_val)`` rows go to validation.
    If more rows are requested than the table holds, the whole table is used.
    In that case the training and validation non-flyers can overlap.

    Args:
        manual_seed: ``random_state`` used to shuffle the concatenated splits.
        split: ``(train_fraction, val_fraction)``.
        frac_no_fly_train: Multiplier on the number of non-flyer training rows.
        frac_no_fly_val: Multiplier on the number of non-flyer validation rows.
        classe_type: Column of the input CSVs copied into the ``Classes`` label
            column. Defaults to the column previously hard-coded here.

    Side effects:
        Reads ``ISA_data/df_flyer_no_miscleavage_astral.csv`` and
        ``ISA_data/df_non_flyer_no_miscleavage_astral.csv``.
        Writes ``df_preprocessed/df_train_astral_4.csv`` and
        ``df_preprocessed/df_val_astral_binary_4.csv``, creating the directory if needed.
    """
    import os

    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split: the head of each flyer class trains, the tail validates.
    list_train_split = []
    list_val_split = []
    total_count = 0
    for cl in [1, 2, 3]:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = df_class.shape[0]
        n_train = int(class_count * split[0])
        list_train_split.append(df_class.iloc[:n_train, :])
        list_val_split.append(df_class.iloc[n_train:, :])
        total_count += class_count
    # Mean flyer-class size; the non-flyer counts are scaled from it.
    total_count = total_count / 3

    n_no_fly_train = int(total_count * split[0] * frac_no_fly_train)
    n_no_fly_val = int(total_count * split[1] * frac_no_fly_val)
    list_train_split.append(df_no_flyer.iloc[:n_no_fly_train, :])
    # Clamp at 0 so an oversized request takes the whole table instead of
    # wrapping around to a short slice from the end.
    val_start = max(0, df_no_flyer.shape[0] - n_no_fly_val)
    list_val_split.append(df_no_flyer.iloc[val_start:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)  # shuffle

    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    os.makedirs('df_preprocessed', exist_ok=True)
    df_train.to_csv('df_preprocessed/df_train_astral_4.csv', index=False)
    df_val.to_csv('df_preprocessed/df_val_astral_binary_4.csv', index=False)

def create_combine_dataset(manual_seed=42, split=(0.8, 0.2), frac_no_fly_train=1, frac_no_fly_val=1,
                           classe_type='Classes MaxLFQ'):
    """Build stratified train/validation CSVs combining the Astral and ISA tables.

    Each source (Astral, ISA) is handled like in the single-source builders.
    Its flyer classes 1, 2 and 3 are split per class:
    - the first ``int(count * split[0])`` rows of each class go to training;
    - the rest go to validation.

    Its non-flyers are added relative to that source's *mean* flyer-class size:
    - the first ``int(mean * split[0] * frac_no_fly_train)`` rows go to training;
    - the last ``int(mean * split[1] * frac_no_fly_val)`` rows go to validation.

    Args:
        manual_seed: ``random_state`` used to shuffle the concatenated splits.
        split: ``(train_fraction, val_fraction)``.
        frac_no_fly_train: Multiplier on the number of non-flyer training rows.
        frac_no_fly_val: Multiplier on the number of non-flyer validation rows.
        classe_type: Column of the input CSVs copied into the ``Classes`` label
            column. Defaults to the column previously hard-coded here.

    Side effects:
        Reads four CSVs under ``ISA_data/``.
        Writes ``df_preprocessed/df_train_combined_7.csv`` and
        ``df_preprocessed/df_val_combined_multiclass_7.csv``, creating the directory if needed.
    """
    import os

    def _load(path):
        # Keep only the sequence and a 'Classes' label copied from classe_type.
        df = pd.read_csv(path)
        df['Classes'] = df[classe_type]
        return df[['Sequences', 'Classes']]

    df_flyer_astral = _load('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
    df_no_flyer_astral = _load('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
    df_flyer = _load('ISA_data/df_flyer_no_miscleavage.csv')
    df_no_flyer = _load('ISA_data/df_non_flyer_no_miscleavage.csv')

    # (flyer table, non-flyer table) per source, Astral first. The append order
    # below is kept the same as before, so the seeded shuffle sees the same row order.
    sources = ((df_flyer_astral, df_no_flyer_astral), (df_flyer, df_no_flyer))

    # Stratified split of the flyers, accumulating each source's flyer total.
    list_train_split = []
    list_val_split = []
    flyer_totals = [0, 0]
    for cl in [1, 2, 3]:
        for i, (df_fly, _) in enumerate(sources):
            df_class = df_fly[df_fly['Classes'] == cl]
            class_count = df_class.shape[0]
            n_train = int(class_count * split[0])
            list_train_split.append(df_class.iloc[:n_train, :])
            list_val_split.append(df_class.iloc[n_train:, :])
            flyer_totals[i] += class_count

    # Non-flyers are scaled from the mean flyer-class size of their own source.
    # Previously the count of the last loop class (3) leaked out of the loop
    # and was used instead.
    for flyer_total, (_, df_non) in zip(flyer_totals, sources):
        mean_count = flyer_total / 3
        n_no_fly_train = int(mean_count * split[0] * frac_no_fly_train)
        n_no_fly_val = int(mean_count * split[1] * frac_no_fly_val)
        list_train_split.append(df_non.iloc[:n_no_fly_train, :])
        # Clamp at 0 so an oversized request takes the whole table instead of
        # wrapping around to a short slice from the end.
        val_start = max(0, df_non.shape[0] - n_no_fly_val)
        list_val_split.append(df_non.iloc[val_start:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)  # shuffle

    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    os.makedirs('df_preprocessed', exist_ok=True)
    df_train.to_csv('df_preprocessed/df_train_combined_7.csv', index=False)
    df_val.to_csv('df_preprocessed/df_val_combined_multiclass_7.csv', index=False)

def density_plot(prediction_path, prediction_path_2, criteria='base'):
    """Overlay prediction-score density curves of two prediction reports.

    Each CSV must provide these columns:
    - ``Non-Flyer``, ``Weak Flyer``, ``Intermediate Flyer``, ``Strong Flyer``, ``Flyer``;
    - ``Binary Classes``, with values ``'Flyer'`` / ``'Non-Flyer'``.

    The first file is labelled "base" and the second "fine tuned" in the legends.
    Two figures are produced:

    1. ``density_sum.png``: densities of the ``Flyer`` column, x in [0, 1].
    2. ``density_base.png``: densities of ``max(flyer classes) - Non-Flyer``, x in [-1, 1].

    Args:
        prediction_path: CSV of the first ("base") model's predictions.
        prediction_path_2: CSV of the second ("fine tuned") model's predictions.
        criteria: Currently unused; kept for backward compatibility.
    """
    def _split_by_binary_class(path):
        # Load one report, add the margin metric, and return (flyer, non_flyer) rows.
        df = pd.read_csv(path)
        df['actual_metrics'] = -df['Non-Flyer'] + df[["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']].max(axis=1)
        return df[df['Binary Classes'] == 'Flyer'], df[df['Binary Classes'] == 'Non-Flyer']

    flyer, non_flyer = _split_by_binary_class(prediction_path)
    flyer2, non_flyer2 = _split_by_binary_class(prediction_path_2)

    legend = ['non-flyer base', 'flyer base', 'non-flyer fine tuned', 'flyer fine tuned']
    curves = ((non_flyer, 'blue'), (flyer, 'red'), (non_flyer2, 'yellow'), (flyer2, 'green'))
    figures = (('Flyer', (0, 1), 'density_sum.png'),
               ('actual_metrics', (-1, 1), 'density_base.png'))

    for column, (x_min, x_max), out_file in figures:
        for subset, color in curves:
            subset[column].plot.density(bw_method=0.2, color=color, linestyle='-', linewidth=2)
        plt.legend(legend)
        plt.xlabel('Flyer index')
        plt.xlim(x_min, x_max)
        # Save before show(): once show() returns, the figure may already be
        # closed, and a savefig() issued afterwards can write an empty image.
        plt.savefig(out_file)
        plt.show()
        plt.clf()





def main():
    """Train a DetectabilityModel on the combined data and write a detectability report.

    Steps:
    1. Build a fresh ``DetectabilityModel``.
    2. Load the train/validation CSVs into a ``DetectabilityDataset``.
    3. Fit with early stopping and best-only checkpointing on ``val_loss``.
    4. Reload the best checkpoint and predict on the validation split.
    5. Generate a ``DetectabilityReport`` under ``./output``.
    """
    total_num_classes = len(CLASSES_LABELS)
    num_cells = 64

    # NOTE(review): keyword spelled ``num_clases`` as in the original call;
    # presumably matches DetectabilityModel's signature, so confirm before "fixing".
    fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                          num_clases=total_num_classes)

    # To start from the pretrained weights instead of a fresh initialisation:
    # fine_tuned_model.load_weights('pretrained_model/original_detectability_fine_tuned_model_FINAL')

    max_pep_length = 40
    # Batch size of the dataset tensors. It affects training, not the values of
    # the predictions.
    batch_size = 16

    print('Initialising dataset')
    ## Data init
    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_4.csv',
                                          val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
                                          data_format='csv',
                                          max_seq_len=max_pep_length,
                                          label_column="Classes",
                                          sequence_column="Sequences",
                                          dataset_columns_to_keep=["Proteins"],
                                          batch_size=batch_size,
                                          with_termini=False,
                                          alphabet=aa_to_int_dict)

    # Stop once val_loss has not improved for 5 consecutive epochs.
    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   mode='min',
                                                   verbose=1,
                                                   patience=5)

    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'

    # Keep only the weights of the epoch with the lowest val_loss.
    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
                                                             monitor='val_loss',
                                                             mode='min',
                                                             verbose=1,
                                                             save_best_only=True,
                                                             save_weights_only=True)
    opti = tf.keras.optimizers.legacy.Adagrad()
    fine_tuned_model.compile(optimizer=opti,
                             loss='SparseCategoricalCrossentropy',
                             metrics='sparse_categorical_accuracy')

    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                              validation_data=fine_tune_data.tensor_val_data,
                                              epochs=150,
                                              callbacks=[callback_FT, model_checkpoint_FT])

    # Reload the best checkpoint written by model_checkpoint_FT.
    # Alternative weights to evaluate instead:
    #   'pretrained_model/original_detectability_fine_tuned_model_FINAL'  (base model)
    fine_tuned_model.load_weights(model_save_path_FT)

    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)

    # True labels of the validation split.
    test_targets_FT = fine_tune_data["val"]["Classes"]

    # Per-peptide table the report needs.
    test_data_df_FT = pd.DataFrame(
        {
            "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # parsed sequences (joined below)
            "Classes": test_targets_FT,
            "Proteins": fine_tune_data["val"]["Proteins"],
        }
    )
    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))

    # The detectability report expects one-hot true labels. Size the one-hot
    # matrix by the model's class count (at least), not just by the largest
    # label present: a validation split missing the top class would otherwise
    # yield fewer columns than the model outputs.
    targets = np.asarray(test_targets_FT, dtype=int)
    num_classes = max(total_num_classes, int(targets.max()) + 1)
    test_targets_FT_one_hot = np.eye(num_classes)[targets]

    report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                    predictions_FT,
                                    test_data_df_FT,
                                    output_path='./output/report_on_astral_4 (Fine tune model (combined 4) categorical train, categorical val )',
                                    history=history_fine_tuned,
                                    rank_by_prot=True,
                                    threshold=None,
                                    name_of_dataset='astral_4 val dataset (Categorical balanced)',
                                    name_of_model='Fine tune model (combined 4)')

    report_FT.generate_report()

if __name__ == "__main__":
    # Optional one-off steps; uncomment the ones you need.
    #
    # Regenerate the CSVs under df_preprocessed/:
    #   create_astral_dataset(frac_no_fly_val=1)
    #   create_combine_dataset(frac_no_fly_val=1, frac_no_fly_train=1)
    #   create_ISA_dataset(frac_no_fly_val=1)
    #
    # Train, evaluate and report:
    main()
    #
    # Compare the score densities of two prediction reports:
    #   density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv',
    #                'output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')