# main_fine_tune.py (8.53 KiB)
# Authored by Schneider Leo (commit 7537263a)
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from dlomix.models import DetectabilityModel
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
from dlomix.data import DetectabilityDataset
from datasets import load_dataset, DatasetDict
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report
def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed=42, split=(0.8, 0.2),
                       frac_no_fly_train=1, frac_no_fly_val=2,
                       flyer_path='ISA_data/df_finetune_no_miscleavage.csv',
                       non_flyer_path='ISA_data/df_non_flyer_no_miscleavage.csv',
                       train_path='temp_fine_tune_df_train.csv',
                       val_path='temp_fine_tune_df_val.csv',
                       flyer_classes=(1, 2, 3)):
    """Build the ISA fine-tuning train/validation CSVs consumed by ``main()``.

    Flyer peptides are split class by class (``flyer_classes``): the first
    ``split[0]`` fraction of each class, in file order, goes to training and the
    remainder to validation. Non-flyer peptides are then added: the first
    ``int(n * split[0] * frac_no_fly_train)`` rows to training and the last
    ``int(n * split[1] * frac_no_fly_val)`` rows to validation, where ``n`` is
    the size of the last class in ``flyer_classes``. Both sets are shuffled with
    ``manual_seed`` and given a constant ``Proteins`` column of 0.

    Args:
        classe_type: Column of both input CSVs used as the class label.
        manual_seed: ``random_state`` used for the final shuffles.
        split: (train, validation) fractions.
        frac_no_fly_train: Multiplier for the number of training non-flyers.
        frac_no_fly_val: Multiplier for the number of validation non-flyers.
        flyer_path: Input CSV of flyers; needs ``Sequences`` and ``classe_type``.
        non_flyer_path: Input CSV of non-flyers; same column requirements.
        train_path: Output CSV for the training set.
        val_path: Output CSV for the validation set.
        flyer_classes: Label values of the flyer classes to split.

    Returns:
        ``(df_train, df_val)``, each with columns ``Sequences``, ``Classes``,
        ``Proteins`` (the same frames written to ``train_path`` / ``val_path``).

    Raises:
        ValueError: If ``flyer_classes`` is empty.
    """
    if not flyer_classes:
        raise ValueError('flyer_classes must contain at least one class')

    df_flyer = pd.read_csv(flyer_path)
    df_no_flyer = pd.read_csv(non_flyer_path)
    # Expose the selected label column under the uniform name 'Classes'.
    df_no_flyer['Classes'] = df_no_flyer[classe_type]
    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
    df_flyer['Classes'] = df_flyer[classe_type]
    df_flyer = df_flyer[['Sequences', 'Classes']]

    # Stratified split of the flyers: the same train/val ratio within every class.
    train_parts = []
    val_parts = []
    for cl in flyer_classes:
        df_class = df_flyer[df_flyer['Classes'] == cl]
        class_count = len(df_class)
        cut = int(class_count * split[0])
        train_parts.append(df_class.iloc[:cut])
        val_parts.append(df_class.iloc[cut:])

    # Non-flyers are sized relative to the last flyer class; training takes rows
    # from the head of the file, validation from the tail.
    n_no_flyer = len(df_no_flyer)
    n_train_nf = int(class_count * split[0] * frac_no_fly_train)
    n_val_nf = int(class_count * split[1] * frac_no_fly_val)
    train_end = min(n_train_nf, n_no_flyer)
    # Clamp at 0: a negative start would make iloc count from the end and
    # silently return fewer rows than are actually available.
    val_start = max(n_no_flyer - n_val_nf, 0)
    if train_end > val_start:
        import warnings
        warnings.warn(
            f'Requested {n_train_nf} train + {n_val_nf} val non-flyers but only '
            f'{n_no_flyer} are available; the train and val sets will overlap.')
    train_parts.append(df_no_flyer.iloc[:train_end])
    val_parts.append(df_no_flyer.iloc[val_start:])

    df_train = pd.concat(train_parts).sample(frac=1, random_state=manual_seed)
    df_val = pd.concat(val_parts).sample(frac=1, random_state=manual_seed)
    # Placeholder protein id required by the dataset/report columns.
    df_train['Proteins'] = 0
    df_val['Proteins'] = 0
    df_train.to_csv(train_path, index=False)
    df_val.to_csv(val_path, index=False)
    return df_train, df_val
def _split_prediction_report(prediction_path):
    """Read a prediction report CSV and return its (flyer, non_flyer) rows.

    Adds an 'actual_metrics' column: the largest of the 'Weak Flyer',
    'Strong Flyer' and 'Intermediate Flyer' scores minus the 'Non-Flyer' score.
    Rows are split on the 'Binary Classes' column.
    """
    df = pd.read_csv(prediction_path)
    best_flyer_score = df[["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']].max(axis=1)
    df['actual_metrics'] = best_flyer_score - df['Non-Flyer']
    flyer = df[df['Binary Classes'] == 'Flyer']
    non_flyer = df[df['Binary Classes'] == 'Non-Flyer']
    return flyer, non_flyer


def _plot_score_densities(series_list, colors, legend, xlim, out_path):
    """Overlay KDE curves of *series_list*, save the figure to *out_path*, show it, then clear it."""
    for series, color in zip(series_list, colors):
        series.plot.density(bw_method=0.2, color=color, linestyle='-', linewidth=2)
    plt.legend(legend)
    plt.xlabel('Flyer index')
    plt.xlim(*xlim)
    # Save before show(): once show() returns, pyplot has closed the figure, so
    # a later savefig() would write a new, empty figure.
    plt.savefig(out_path)
    plt.show()
    plt.clf()


def density_plot(prediction_path, prediction_path_2, criteria='base'):
    """Compare flyer / non-flyer score densities of two prediction reports.

    Both CSVs need the columns 'Non-Flyer', 'Weak Flyer', 'Intermediate Flyer',
    'Strong Flyer', 'Flyer' and 'Binary Classes'. The first report is labelled
    "base" and the second "fine tuned" in the legends.

    Two figures are saved to the working directory and shown:
      * density_sum.png  - density of the 'Flyer' column, x range [0, 1];
      * density_base.png - density of (max flyer-class score - 'Non-Flyer'),
        x range [-1, 1].

    Args:
        prediction_path: Prediction report CSV of the base model.
        prediction_path_2: Prediction report CSV of the fine-tuned model.
        criteria: Currently unused; kept for backward compatibility.
    """
    flyer, non_flyer = _split_prediction_report(prediction_path)
    flyer2, non_flyer2 = _split_prediction_report(prediction_path_2)

    groups = (non_flyer, flyer, non_flyer2, flyer2)
    colors = ('blue', 'red', 'yellow', 'green')
    legend = ['non-flyer base', 'flyer base', 'non-flyer fine tuned', 'flyer fine tuned']

    _plot_score_densities([g['Flyer'] for g in groups], colors, legend,
                          (0, 1), 'density_sum.png')
    _plot_score_densities([g['actual_metrics'] for g in groups], colors, legend,
                          (-1, 1), 'density_base.png')
def main():
    """Fine-tune the detectability model on the ISA split and write a report.

    Expects ``temp_fine_tune_df_train.csv`` and ``temp_fine_tune_df_val.csv``
    in the working directory (as written by ``create_ISA_dataset``). The model
    is initialised from the pretrained weights, trained for one epoch with
    early stopping and checkpointing, then evaluated on the validation split.
    The resulting ``DetectabilityReport`` is generated under ``./output``.
    """
    total_num_classes = len(CLASSES_LABELS)
    num_cells = 64
    max_pep_length = 40
    # Batch size of the dataset pipelines. Per the original note it does not
    # change predictions, but it does affect the fine-tuning step below.
    batch_size = 16
    base_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
    checkpoint_path = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'

    # Build the model once and start from the pretrained weights.
    model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes)
    model.load_weights(base_model_path)

    print('Initialising dataset')
    fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv',
                                          val_data_source='temp_fine_tune_df_val.csv',
                                          data_format='csv',
                                          max_seq_len=max_pep_length,
                                          label_column="Classes",
                                          sequence_column="Sequences",
                                          dataset_columns_to_keep=["Proteins"],
                                          batch_size=batch_size,
                                          with_termini=False,
                                          alphabet=aa_to_int_dict)

    # With epochs=1 a patience of 5 can never trigger; it only matters if the
    # epoch count is raised.
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      mode='min',
                                                      verbose=1,
                                                      patience=5)
    # Writes the weights only when val_loss improves.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                    monitor='val_loss',
                                                    mode='min',
                                                    verbose=1,
                                                    save_best_only=True,
                                                    save_weights_only=True)
    model.compile(optimizer=tf.keras.optimizers.legacy.Adagrad(),
                  loss='SparseCategoricalCrossentropy',
                  metrics='sparse_categorical_accuracy')
    history = model.fit(fine_tune_data.tensor_train_data,
                        validation_data=fine_tune_data.tensor_val_data,
                        epochs=1,
                        callbacks=[early_stopping, checkpoint])

    # NOTE(review): evaluation reloads the *base* weights, so the report reflects
    # the pretrained model rather than the weights just trained (the report
    # names below say "Base model" accordingly). Load `checkpoint_path` instead
    # to evaluate the best fine-tuned checkpoint.
    model.load_weights(base_model_path)
    predictions = model.predict(fine_tune_data.tensor_val_data)

    val_split = fine_tune_data["val"]
    val_targets = val_split["Classes"]
    # Dataframe expected by the report; parsed sequences are joined back into strings.
    report_df = pd.DataFrame(
        {
            "Sequences": ["".join(seq) for seq in val_split["_parsed_sequence"]],
            "Classes": val_targets,
            "Proteins": val_split["Proteins"],
        }
    )

    # The report expects one-hot targets. Size them by the full label set (the
    # model is built with total_num_classes outputs), not by the largest label
    # present, which would be too narrow whenever the top class is absent.
    targets_one_hot = np.eye(total_num_classes)[val_targets]

    report = DetectabilityReport(targets_one_hot,
                                 predictions,
                                 report_df,
                                 output_path='./output/report_on_ISA (Base model categorical train, binary val )',
                                 history=history,
                                 rank_by_prot=True,
                                 threshold=None,
                                 name_of_dataset='ISA val dataset (binary balanced)',
                                 name_of_model='Base model (ISA)')
    report.generate_report()
if __name__ == '__main__':
    # create_ISA_dataset() writes the temp_fine_tune_df_{train,val}.csv files
    # that main() loads, so the steps must run in this order.
    pipeline = (create_ISA_dataset, main)
    for step in pipeline:
        step()
    # Optional: compare the score densities of two previously generated reports.
    # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')