# Schneider Leo authored 70c39f3f
# main_fine_tune.py 12.49 KiB
import os

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from dlomix.models import DetectabilityModel
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
from dlomix.data import DetectabilityDataset
from datasets import load_dataset, DatasetDict
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report
def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed=42, split=(0.8, 0.2),
                       frac_no_fly_train=1, frac_no_fly_val=1,
                       flyer_path='ISA_data/df_flyer_no_miscleavage.csv',
                       non_flyer_path='ISA_data/df_non_flyer_no_miscleavage.csv',
                       train_output_path='df_preprocessed/df_train_ISA.csv',
                       val_output_path='df_preprocessed/df_val_ISA_multiclass.csv'):
    """Build the ISA train/validation CSVs with a per-class split of flyers.

    Flyer rows of classes 1, 2 and 3 are split class by class. The first
    ``split[0]`` fraction of each class, in file order, goes to training and
    the rest goes to validation.

    Non-flyer rows are then added, sized from the mean flyer class count.
    Training takes them from the head of the non-flyer file and validation
    from its tail. The two non-flyer slices never overlap.

    Both frames are shuffled with ``manual_seed``, get a constant
    ``Proteins`` column of 0, and are written to CSV.

    Args:
        classe_type: Column of both input files copied into ``Classes``.
        manual_seed: ``random_state`` for the final shuffles.
        split: ``(train_fraction, val_fraction)``. ``split[0]`` splits each
            flyer class; ``split[1]`` only sizes the non-flyer validation part.
        frac_no_fly_train: Multiplier on the number of training non-flyers.
        frac_no_fly_val: Multiplier on the number of validation non-flyers.
        flyer_path: Input CSV of flyers (needs ``Sequences`` and
            ``classe_type`` columns).
        non_flyer_path: Input CSV of non-flyers (same columns).
        train_output_path: Destination CSV for the training frame.
        val_output_path: Destination CSV for the validation frame.
            Missing parent directories of both outputs are created.

    Returns:
        ``(df_train, df_val)``, the frames that were written.
    """
    df_flyer = pd.read_csv(flyer_path)
    df_no_flyer = pd.read_csv(non_flyer_path)
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split: every flyer class keeps the same train/val ratio.
    list_train_split = []
    list_val_split = []
    total_count = 0
    for cl in [1, 2, 3]:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = df_class.shape[0]
        n_train = int(class_count * split[0])
        list_train_split.append(df_class.iloc[:n_train, :])
        list_val_split.append(df_class.iloc[n_train:, :])
        total_count += class_count
    # Non-flyers are balanced against the mean size of the three flyer classes.
    mean_class_count = total_count / 3

    n_no_fly = df_no_flyer.shape[0]
    # Validation takes rows from the tail. Clamp the count so that a request
    # larger than the file cannot produce a negative iloc start; a negative
    # start would wrap around and silently return fewer rows.
    n_val_no_fly = min(int(mean_class_count * split[1] * frac_no_fly_val), n_no_fly)
    # Training takes rows from the head, but never rows already reserved for
    # validation, which would leak validation peptides into training.
    n_train_no_fly = min(int(mean_class_count * split[0] * frac_no_fly_train),
                         n_no_fly - n_val_no_fly)
    list_train_split.append(df_no_flyer.iloc[:n_train_no_fly, :])
    list_val_split.append(df_no_flyer.iloc[n_no_fly - n_val_no_fly:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)  # shuffle
    # Constant placeholder column. main() keeps a "Proteins" column via
    # dataset_columns_to_keep, so it is presumably required downstream.
    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    for frame, path in ((df_train, train_output_path), (df_val, val_output_path)):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        frame.to_csv(path, index=False)
    return df_train, df_val
def create_astral_dataset(manual_seed=42, split=(0.8, 0.2), frac_no_fly_train=1, frac_no_fly_val=1,
                          flyer_path='ISA_data/df_flyer_no_miscleavage_astral.csv',
                          non_flyer_path='ISA_data/df_non_flyer_no_miscleavage_astral.csv',
                          train_output_path='df_preprocessed/df_train_astral_4.csv',
                          val_output_path='df_preprocessed/df_val_astral_binary_4.csv',
                          classe_type='Classes MaxLFQ'):
    """Build the Astral train/validation CSVs with a per-class split of flyers.

    Flyer rows of classes 1, 2 and 3 are split class by class. The first
    ``split[0]`` fraction of each class, in file order, goes to training and
    the rest goes to validation.

    Non-flyer rows are sized from the mean flyer class count. Training takes
    them from the head of the non-flyer file and validation from its tail;
    the two slices never overlap.

    Both frames are shuffled with ``manual_seed``, get a constant
    ``Proteins`` column of 0, and are written to CSV.

    NOTE(review): main() reads 'df_preprocessed/df_val_astral_multiclass_4.csv',
    not the default ``val_output_path`` written here. Confirm which file is
    intended.

    Args:
        manual_seed: ``random_state`` for the final shuffles.
        split: ``(train_fraction, val_fraction)``. ``split[0]`` splits each
            flyer class; ``split[1]`` only sizes the non-flyer validation part.
        frac_no_fly_train: Multiplier on the number of training non-flyers.
        frac_no_fly_val: Multiplier on the number of validation non-flyers.
        flyer_path: Input CSV of flyers (needs ``Sequences`` and
            ``classe_type`` columns).
        non_flyer_path: Input CSV of non-flyers (same columns).
        train_output_path: Destination CSV for the training frame.
        val_output_path: Destination CSV for the validation frame.
            Missing parent directories of both outputs are created.
        classe_type: Column copied into ``Classes`` (default keeps the
            previous hard-coded 'Classes MaxLFQ').

    Returns:
        ``(df_train, df_val)``, the frames that were written.
    """
    df_flyer = pd.read_csv(flyer_path)
    df_no_flyer = pd.read_csv(non_flyer_path)
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split: every flyer class keeps the same train/val ratio.
    list_train_split = []
    list_val_split = []
    total_count = 0
    for cl in [1, 2, 3]:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = df_class.shape[0]
        n_train = int(class_count * split[0])
        list_train_split.append(df_class.iloc[:n_train, :])
        list_val_split.append(df_class.iloc[n_train:, :])
        total_count += class_count
    # Non-flyers are balanced against the mean size of the three flyer classes.
    mean_class_count = total_count / 3

    n_no_fly = df_no_flyer.shape[0]
    # Tail slice for validation. Clamp the count so a negative iloc start
    # cannot wrap around and silently return fewer rows.
    n_val_no_fly = min(int(mean_class_count * split[1] * frac_no_fly_val), n_no_fly)
    # Head slice for training, limited so it never overlaps the validation rows.
    n_train_no_fly = min(int(mean_class_count * split[0] * frac_no_fly_train),
                         n_no_fly - n_val_no_fly)
    list_train_split.append(df_no_flyer.iloc[:n_train_no_fly, :])
    list_val_split.append(df_no_flyer.iloc[n_no_fly - n_val_no_fly:, :])

    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)  # shuffle
    # Constant placeholder column. main() keeps a "Proteins" column via
    # dataset_columns_to_keep.
    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    for frame, path in ((df_train, train_output_path), (df_val, val_output_path)):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        frame.to_csv(path, index=False)
    return df_train, df_val
def create_combine_dataset(manual_seed=42, split=(0.8, 0.2), frac_no_fly_train=1, frac_no_fly_val=1,
                           astral_flyer_path='ISA_data/df_flyer_no_miscleavage_astral_7.csv',
                           astral_non_flyer_path='ISA_data/df_non_flyer_no_miscleavage_astral.csv',
                           isa_flyer_path='ISA_data/df_flyer_no_miscleavage.csv',
                           isa_non_flyer_path='ISA_data/df_non_flyer_no_miscleavage.csv',
                           train_output_path='df_preprocessed/df_train_combined_7.csv',
                           val_output_path='df_preprocessed/df_val_combined_multiclass_7.csv',
                           classe_type='Classes MaxLFQ'):
    """Build train/validation CSVs combining the Astral and ISA sources.

    Each source is processed independently, the same way as in
    create_ISA_dataset / create_astral_dataset:

    * Flyer classes 1, 2 and 3 are split class by class, with ``split[0]``
      of each class (in file order) going to training.
    * Non-flyers are sized from that source's mean flyer class count. They
      are taken once: training from the head, validation from the tail, with
      no overlap. (Slicing them inside the per-class loop would append the
      same rows several times, or size them by a single class.)

    The two sources are then concatenated, shuffled with ``manual_seed``,
    given a constant ``Proteins`` column of 0, and written to CSV.

    Args:
        manual_seed: ``random_state`` for the final shuffles.
        split: ``(train_fraction, val_fraction)``.
        frac_no_fly_train: Multiplier on the number of training non-flyers.
        frac_no_fly_val: Multiplier on the number of validation non-flyers.
        astral_flyer_path: Astral flyer CSV (needs ``Sequences`` and
            ``classe_type`` columns).
        astral_non_flyer_path: Astral non-flyer CSV (same columns).
        isa_flyer_path: ISA flyer CSV (same columns).
        isa_non_flyer_path: ISA non-flyer CSV (same columns).
        train_output_path: Destination CSV for the training frame.
        val_output_path: Destination CSV for the validation frame.
            Missing parent directories of both outputs are created.
        classe_type: Column copied into ``Classes``.

    Returns:
        ``(df_train, df_val)``, the frames that were written.
    """

    def _load(path):
        # Keep only the sequence and the chosen label column, renamed 'Classes'.
        df = pd.read_csv(path)
        df['Classes'] = df[classe_type]
        return df[['Sequences', 'Classes']]

    def _split_source(df_flyer, df_no_flyer):
        # Per-class flyer split plus one non-overlapping non-flyer split.
        train_parts = []
        val_parts = []
        total_count = 0
        for cl in [1, 2, 3]:
            df_class = df_flyer[df_flyer['Classes'] == cl]
            class_count = df_class.shape[0]
            n_train = int(class_count * split[0])
            train_parts.append(df_class.iloc[:n_train, :])
            val_parts.append(df_class.iloc[n_train:, :])
            total_count += class_count
        mean_class_count = total_count / 3
        n_no_fly = df_no_flyer.shape[0]
        # Clamp so the tail start never goes negative (iloc would wrap around).
        n_val_no_fly = min(int(mean_class_count * split[1] * frac_no_fly_val), n_no_fly)
        # Head slice never reaches into the rows reserved for validation.
        n_train_no_fly = min(int(mean_class_count * split[0] * frac_no_fly_train),
                             n_no_fly - n_val_no_fly)
        train_parts.append(df_no_flyer.iloc[:n_train_no_fly, :])
        val_parts.append(df_no_flyer.iloc[n_no_fly - n_val_no_fly:, :])
        return train_parts, val_parts

    astral_train, astral_val = _split_source(_load(astral_flyer_path), _load(astral_non_flyer_path))
    isa_train, isa_val = _split_source(_load(isa_flyer_path), _load(isa_non_flyer_path))

    df_train = pd.concat(astral_train + isa_train).sample(frac=1, random_state=manual_seed)  # shuffle
    df_val = pd.concat(astral_val + isa_val).sample(frac=1, random_state=manual_seed)  # shuffle
    # Constant placeholder column. main() keeps a "Proteins" column via
    # dataset_columns_to_keep.
    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    for frame, path in ((df_train, train_output_path), (df_val, val_output_path)):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        frame.to_csv(path, index=False)
    return df_train, df_val
def density_plot(prediction_path, prediction_path_2, criteria='base',
                 sum_output_path='density_sum.png', base_output_path='density_base.png'):
    """Compare score densities of flyers vs non-flyers from two prediction CSVs.

    Each CSV must contain the columns 'Non-Flyer', 'Weak Flyer',
    'Intermediate Flyer', 'Strong Flyer', 'Flyer' and 'Binary Classes'.
    The legend labels the first file "base" and the second "fine tuned".

    Two figures are produced, each with one KDE curve per
    (file, binary class) group:

    1. The 'Flyer' column, saved to ``sum_output_path``.
    2. The margin max(flyer-class scores) - 'Non-Flyer', saved to
       ``base_output_path``.

    Args:
        prediction_path: CSV of the first ("base") model's predictions.
        prediction_path_2: CSV of the second ("fine tuned") model's predictions.
        criteria: Unused; kept for backward compatibility.
        sum_output_path: PNG path for the 'Flyer' density figure.
        base_output_path: PNG path for the margin density figure.
    """
    flyer_columns = ["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']

    def _load(path):
        # Returns (non_flyer_rows, flyer_rows) with the margin column added.
        df = pd.read_csv(path)
        # Assuming the scores are probabilities in [0, 1], this margin lies in
        # [-1, 1], matching the x-limits of the second figure.
        df['actual_metrics'] = -df['Non-Flyer'] + df[flyer_columns].max(axis=1)
        return df[df['Binary Classes'] == 'Non-Flyer'], df[df['Binary Classes'] == 'Flyer']

    non_flyer, flyer = _load(prediction_path)
    non_flyer2, flyer2 = _load(prediction_path_2)
    # Plot order must match the legend labels below.
    groups = [(non_flyer, 'blue'), (flyer, 'red'), (non_flyer2, 'yellow'), (flyer2, 'green')]
    legend = ['non-flyer base', 'flyer base', 'non-flyer fine tuned', 'flyer fine tuned']

    # Column 'Flyer' and the margin both use the "Flyer index" x label.
    for column, x_limits, output_path in (('Flyer', (0, 1), sum_output_path),
                                          ('actual_metrics', (-1, 1), base_output_path)):
        for frame, color in groups:
            frame[column].plot.density(bw_method=0.2, color=color, linestyle='-', linewidth=2)
        plt.legend(legend)
        plt.xlabel('Flyer index')
        plt.xlim(*x_limits)
        # Save BEFORE show(): with blocking backends show() closes the figure,
        # and a later savefig() would write an empty image.
        plt.savefig(output_path)
        plt.show()
        plt.clf()
def main(load_model_path=None):
    """Fine-tune a DetectabilityModel on the combined data and write a report.

    Steps:

    1. Build the model.
    2. Load the train/validation CSVs with DetectabilityDataset.
    3. Train with EarlyStopping (patience 5 on val_loss) and a ModelCheckpoint
       that keeps only the best weights.
    4. Reload the best checkpoint and predict on the validation split.
    5. Generate a DetectabilityReport.

    Args:
        load_model_path: Optional weights to load before training. The default
            ``None`` trains from freshly initialised weights (the previous
            behaviour, where the load was commented out). Pass e.g.
            'pretrained_model/original_detectability_fine_tuned_model_FINAL'
            to start from the pretrained base model.
    """
    total_num_classes = len(CLASSES_LABELS)
    num_cells = 64
    fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                          num_clases=total_num_classes)
    if load_model_path is not None:
        fine_tuned_model.load_weights(load_model_path)

    max_pep_length = 40
    # Mini-batch size; here it is used for training as well as prediction.
    batch_size = 16
    print('Initialising dataset')
    ## Data init
    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_4.csv',
                                          val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
                                          data_format='csv',
                                          max_seq_len=max_pep_length,
                                          label_column="Classes",
                                          sequence_column="Sequences",
                                          dataset_columns_to_keep=["Proteins"],
                                          batch_size=batch_size,
                                          with_termini=False,
                                          alphabet=aa_to_int_dict)

    # Stop once val_loss has not improved for 5 consecutive epochs.
    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   mode='min',
                                                   verbose=1,
                                                   patience=5)
    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
    # Keep only the weights of the epoch with the lowest val_loss.
    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
                                                             monitor='val_loss',
                                                             mode='min',
                                                             verbose=1,
                                                             save_best_only=True,
                                                             save_weights_only=True)
    opti = tf.keras.optimizers.legacy.Adagrad()
    # Integer class labels, hence the sparse loss/metric variants.
    fine_tuned_model.compile(optimizer=opti,
                             loss='SparseCategoricalCrossentropy',
                             metrics='sparse_categorical_accuracy')
    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                              validation_data=fine_tune_data.tensor_val_data,
                                              epochs=150,
                                              callbacks=[callback_FT, model_checkpoint_FT])

    # Restore the best checkpoint rather than the last epoch's weights.
    # To evaluate the base model instead, load
    # 'pretrained_model/original_detectability_fine_tuned_model_FINAL' here.
    fine_tuned_model.load_weights(model_save_path_FT)
    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)

    # Access the val dataset and get the Classes column.
    test_targets_FT = fine_tune_data["val"]["Classes"]
    # The dataframe needed for the report.
    test_data_df_FT = pd.DataFrame(
        {
            "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # raw parsed sequences
            "Classes": test_targets_FT,  # test targets from above
            "Proteins": fine_tune_data["val"]["Proteins"],  # Proteins column of the dataset object
        }
    )
    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))

    # The detectability report expects one-hot true labels. Size the encoding
    # by the model's class count, not by the largest label present, so its
    # width presumably matches the prediction columns even when the highest
    # class is absent from the validation split.
    test_targets_FT_one_hot = np.eye(total_num_classes)[test_targets_FT]

    report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                    predictions_FT,
                                    test_data_df_FT,
                                    output_path='./output/report_on_astral_4 (Fine tune model (combined 4) categorical train, categorical val )',
                                    history=history_fine_tuned,
                                    rank_by_prot=True,
                                    threshold=None,
                                    name_of_dataset='astral_4 val dataset (Categorical balanced)',
                                    name_of_model='Fine tune model (combined 4)')
    report_FT.generate_report()
if __name__ == '__main__':
    # Dataset (re)generation: uncomment the step(s) needed before training.
    #   create_astral_dataset(frac_no_fly_val=1)
    #   create_combine_dataset(frac_no_fly_val=1, frac_no_fly_train=1)
    #   create_ISA_dataset(frac_no_fly_val=1)

    # Fine-tune and write the detectability report.
    main()

    # Optional: compare the prediction distributions of two report CSVs.
    #   density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv',
    #                'output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')