Skip to content
Snippets Groups Projects
Commit 0a226689 authored by Schneider Leo's avatar Schneider Leo
Browse files

binary pretraining

parent 89df63ba
No related branches found
No related tags found
No related merge requests found
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from dlomix.models import DetectabilityModel
from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
from dlomix.data import DetectabilityDataset
from os.path import join, exists
from os import makedirs
from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve
def _split_binary_frames(df_flyer, df_no_flyer, manual_seed, split,
                         frac_no_fly_train, frac_no_fly_val):
    """Build shuffled, labelled train/val frames from two peptide tables.

    Rows of ``df_flyer`` get ``Classes = 1`` and rows of ``df_no_flyer`` get
    ``Classes = 0``. Only the ``Sequences`` and ``Classes`` columns are kept,
    and a constant ``Proteins = 0`` column is appended.

    The flyer rows are cut at ``int(len(df_flyer) * split[0])``: the head goes
    to train, the tail to validation. The number of non-flyer rows is derived
    from the flyer count (``flyer_count * split[i] * frac_no_fly_*``), taken
    from the head of ``df_no_flyer`` for train and from its tail for
    validation.

    Returns:
        A ``(df_train, df_val)`` tuple of shuffled DataFrames.
    """
    flyers = df_flyer.assign(Classes=1)[['Sequences', 'Classes']]
    non_flyers = df_no_flyer.assign(Classes=0)[['Sequences', 'Classes']]

    flyer_count = flyers.shape[0]
    non_flyer_count = non_flyers.shape[0]
    flyer_cut = int(flyer_count * split[0])

    # Requested non-flyer sizes, clamped to what the table actually holds.
    n_train = int(flyer_count * split[0] * frac_no_fly_train)
    n_train = max(0, min(n_train, non_flyer_count))
    n_val = max(0, int(flyer_count * split[1] * frac_no_fly_val))

    # The validation rows come from the tail. Never start before the end of
    # the train slice (that would leak training peptides into validation) and
    # never at a negative position (iloc would silently wrap around).
    val_start = max(n_train, non_flyer_count - n_val)

    # Keep the flyer-then-non-flyer concat order so the seeded shuffle is
    # reproducible.
    df_train = pd.concat([flyers.iloc[:flyer_cut, :], non_flyers.iloc[:n_train, :]])
    df_val = pd.concat([flyers.iloc[flyer_cut:, :], non_flyers.iloc[val_start:, :]])

    df_train = df_train.sample(frac=1, random_state=manual_seed).assign(Proteins=0)
    df_val = df_val.sample(frac=1, random_state=manual_seed).assign(Proteins=0)
    return df_train, df_val


def _write_binary_dataset(flyer_csv, non_flyer_csv, train_csv, val_csv,
                          manual_seed, split, frac_no_fly_train, frac_no_fly_val):
    """Read the two source CSVs, split them and write the train/val CSVs."""
    from pathlib import Path

    df_train, df_val = _split_binary_frames(
        pd.read_csv(flyer_csv),
        pd.read_csv(non_flyer_csv),
        manual_seed,
        split,
        frac_no_fly_train,
        frac_no_fly_val,
    )
    for frame, target in ((df_train, train_csv), (df_val, val_csv)):
        # Create the output folder on demand instead of failing in to_csv.
        Path(target).parent.mkdir(parents=True, exist_ok=True)
        frame.to_csv(target, index=False)


def create_ISA_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
    """Write the binary (flyer = 1 / non-flyer = 0) ISA train and val CSVs.

    Reads ``ISA_data/df_flyer_no_miscleavage.csv`` and
    ``ISA_data/df_non_flyer_no_miscleavage.csv`` and writes
    ``df_preprocessed/df_train_ISA_binary.csv`` and
    ``df_preprocessed/df_val_ISA_binary.csv``.

    Args:
        manual_seed: Seed used to shuffle both output frames.
        split: ``(train_fraction, val_fraction)`` applied to the flyer count.
        frac_no_fly_train: Multiplier on the number of non-flyer train rows.
        frac_no_fly_val: Multiplier on the number of non-flyer val rows.
    """
    _write_binary_dataset(
        'ISA_data/df_flyer_no_miscleavage.csv',
        'ISA_data/df_non_flyer_no_miscleavage.csv',
        'df_preprocessed/df_train_ISA_binary.csv',
        'df_preprocessed/df_val_ISA_binary.csv',
        manual_seed,
        split,
        frac_no_fly_train,
        frac_no_fly_val,
    )


def create_astral_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
    """Write the binary (flyer = 1 / non-flyer = 0) astral train and val CSVs.

    Reads ``ISA_data/df_flyer_no_miscleavage_astral.csv`` and
    ``ISA_data/df_non_flyer_no_miscleavage_astral.csv`` and writes
    ``df_preprocessed/df_train_astral_binary.csv`` and
    ``df_preprocessed/df_val_astral_binary.csv``.

    Args:
        manual_seed: Seed used to shuffle both output frames.
        split: ``(train_fraction, val_fraction)`` applied to the flyer count.
        frac_no_fly_train: Multiplier on the number of non-flyer train rows.
        frac_no_fly_val: Multiplier on the number of non-flyer val rows.
    """
    _write_binary_dataset(
        'ISA_data/df_flyer_no_miscleavage_astral.csv',
        'ISA_data/df_non_flyer_no_miscleavage_astral.csv',
        'df_preprocessed/df_train_astral_binary.csv',
        'df_preprocessed/df_val_astral_binary.csv',
        manual_seed,
        split,
        frac_no_fly_train,
        frac_no_fly_val,
    )
def main(epoch):
    """Fine-tune a two-class DetectabilityModel and score the validation split.

    A four-class ``DetectabilityModel`` is loaded from the pretrained weights,
    all of its weight arrays except the last two are copied into a fresh
    two-class model, and that model is trained on the astral binary CSVs with
    early stopping and best-only checkpointing on ``val_loss``. The
    checkpointed weights are then reloaded and used to predict the validation
    data.

    Args:
        epoch: Maximum number of epochs passed to ``fit``.

    Returns:
        A tuple ``(test_data_df_FT, history_fine_tuned)``: a DataFrame with
        the validation sequences, their labels, the ``Proteins`` column, both
        predicted probabilities and the arg-max class, and the object
        returned by ``fit``.
    """
    # --- configuration -----------------------------------------------------
    num_cells = 64
    total_num_classes = 2
    max_pep_length = 40
    batch_size = 16  # Has no impact for prediction
    pretrained_weights_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
    # model fine tuned on ISA data
    best_weights_path = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'

    # --- models --------------------------------------------------------------
    fine_tuned_model = DetectabilityModel(num_units=num_cells,
                                          num_clases=total_num_classes)
    fine_tuned_model.build((None, max_pep_length))

    base_arch = DetectabilityModel(num_units=num_cells,
                                   num_clases=4)
    base_arch.load_weights(pretrained_weights_path)
    base_arch.build((None, max_pep_length))

    # Partially load the pretrained (multiclass) weights: everything is
    # reused except the last two arrays, which are replaced by zeros sized
    # for two outputs.
    # NOTE(review): presumably these are the output layer's kernel and bias
    # -- confirm against DetectabilityModel.
    transferred_weights = base_arch.get_weights()
    transferred_weights[-1] = np.zeros((2,), dtype=np.float32)
    transferred_weights[-2] = np.zeros((128, 2), dtype=np.float32)
    fine_tuned_model.set_weights(transferred_weights)

    # --- data ----------------------------------------------------------------
    print('Initialising dataset')
    fine_tune_data = DetectabilityDataset(
        data_source='df_preprocessed/df_train_astral_binary.csv',
        val_data_source='df_preprocessed/df_val_astral_binary.csv',
        data_format='csv',
        max_seq_len=max_pep_length,
        label_column="Classes",
        sequence_column="Sequences",
        dataset_columns_to_keep=["Proteins"],
        batch_size=batch_size,
        with_termini=False,
        alphabet=aa_to_int_dict,
    )

    # --- training ------------------------------------------------------------
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        verbose=1,
        patience=5,
    )
    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=best_weights_path,
        monitor='val_loss',
        mode='min',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    )
    fine_tuned_model.compile(
        optimizer=tf.keras.optimizers.legacy.Adagrad(),
        loss='SparseCategoricalCrossentropy',
        metrics='sparse_categorical_accuracy',
    )
    history_fine_tuned = fine_tuned_model.fit(
        fine_tune_data.tensor_train_data,
        validation_data=fine_tune_data.tensor_val_data,
        epochs=epoch,
        callbacks=[early_stopping, best_checkpoint],
    )

    # --- evaluation ----------------------------------------------------------
    # Reload the weights written by the checkpoint callback before predicting.
    fine_tuned_model.load_weights(best_weights_path)
    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)

    # Labels of the validation split.
    test_targets_FT = fine_tune_data["val"]["Classes"]

    # The dataframe needed for the report.
    report_columns = {
        "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # raw parsed sequences
        "Classes": test_targets_FT,
        "Proteins": fine_tune_data["val"]["Proteins"],
        "Prob non flyer": predictions_FT[:, 0],
        "Prob flyer": predictions_FT[:, 1],
        "Predicted classes": np.argmax(predictions_FT, axis=1),
    }
    test_data_df_FT = pd.DataFrame(report_columns)

    # Each parsed sequence is an iterable of strings; glue it back together.
    test_data_df_FT["Sequences"] = test_data_df_FT["Sequences"].apply("".join)
    return test_data_df_FT, history_fine_tuned
def plot_and_save_metrics(history, base_path):
    """Save one train/validation curve image per metric in ``history``.

    Every key of ``history.history`` that does not start with ``"val_"`` or
    ``"_"`` is plotted against the epoch index and saved to
    ``join(base_path, <metric name>)``. The matching ``"val_" + name`` series
    is added to the same figure when it exists; a metric without a validation
    counterpart is plotted on its own instead of raising ``KeyError``.

    Args:
        history: Object exposing a ``history`` dict of per-epoch value lists
            (as returned by ``fit``).
        base_path: Output directory; created if it does not exist.
    """
    history_dict = history.history
    makedirs(base_path, exist_ok=True)

    train_metric_names = [
        name for name in history_dict if not name.startswith(("val_", "_"))
    ]
    for metric_name in train_metric_names:
        # A dedicated figure per metric, so leftovers on a previously open
        # figure never end up in the saved image.
        fig, ax = plt.subplots()
        ax.plot(history_dict[metric_name])
        legend_entries = ["train"]

        val_key = "val_" + metric_name
        if val_key in history_dict:
            ax.plot(history_dict[val_key])
            legend_entries.append("val")

        ax.set_title(metric_name, fontsize=10)
        ax.set_ylabel(metric_name)
        ax.set_xlabel("epoch")
        ax.legend(legend_entries, loc="best")

        fig.savefig(join(base_path, metric_name), bbox_inches="tight", dpi=90)
        plt.close(fig)
def plot_confusion_matrix(df, base_path):
    """Save the binary confusion matrix of ``df`` as an image.

    The matrix compares ``df["Classes"]`` (true labels) with
    ``df["Predicted classes"]`` and is written to
    ``join(base_path, "confusion_matrix_binary")``.

    Args:
        df: DataFrame with ``"Classes"`` and ``"Predicted classes"`` columns
            holding the labels 0 and 1.
        base_path: Output directory; created if it does not exist.
    """
    # Pin the label set so the matrix is always 2x2 and matches the two
    # display labels, even when one class is absent from both columns.
    conf_matrix = confusion_matrix(
        df["Classes"],
        df["Predicted classes"],
        labels=[0, 1],
    )
    makedirs(base_path, exist_ok=True)

    conf_matrix_disp = ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"]
    )
    fig, ax = plt.subplots()
    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
    ax.set_title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)

    fig.savefig(
        join(base_path, "confusion_matrix_binary"), bbox_inches="tight", dpi=80
    )
    plt.close(fig)
def plot_roc(df,base_path):
    """Save the ROC curve of the flyer probability as an image.

    The curve is computed from ``df["Classes"]`` (true labels) and
    ``df["Prob flyer"]`` (scores) and written to
    ``join(base_path, "ROC_curve_binary_classification")``. The area under
    the curve is shown in the legend.

    Args:
        df: DataFrame with ``"Classes"`` and ``"Prob flyer"`` columns.
        base_path: Output directory; created if it does not exist.
    """
    fpr, tpr, _thresholds = roc_curve(
        np.array(df["Classes"]),
        np.array(df["Prob flyer"]),
    )
    AUC_score = auc(fpr, tpr)

    # Same directory handling as the other plot helpers, so this function
    # also works when it is the first one to be called.
    makedirs(base_path, exist_ok=True)

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label="ROC curve of (area = {})".format(AUC_score))
    ax.set_title(
        "Receiver operating characteristic curve (Binary classification)",
        y=1.04,
        fontsize=10,
    )
    ax.set_ylabel("True Positive Rate")
    ax.set_xlabel("False Positive Rate")
    # Without a legend the label carrying the AUC value is never rendered.
    ax.legend(loc="lower right")

    fig.savefig(
        join(base_path, "ROC_curve_binary_classification"),
        bbox_inches="tight",
        dpi=90,
    )
    plt.close(fig)
if __name__ == '__main__':
    # 1) regenerate the astral binary train/val CSVs,
    # 2) fine-tune for at most 150 epochs,
    # 3) write the training curves, confusion matrix and ROC curve.
    report_dir = 'output/binary_astral'

    create_astral_binary_dataset()
    val_report_df, fit_history = main(epoch=150)

    plot_and_save_metrics(fit_history, report_dir)
    plot_confusion_matrix(val_report_df, report_dir)
    plot_roc(val_report_df, report_dir)
\ No newline at end of file
Sequences,Classes,Proteins Sequences,Classes,Proteins
IVDDLSALTVLEASELSK,0,0 IVDDLSALTVLEASELSK,0,0
\ No newline at end of file IVDDLSALTVLEASELSK,1,0
IVDDLSALTVLEASELSK,1,0
IVDDLSALTVLEASELSK,0,0
IVDDLSALTVLEASELSK,0,0
IVDDLSALTVLEASELSK,0,0
IVDDLSALTVLEASELSK,0,0
...@@ -19,14 +19,16 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0 ...@@ -19,14 +19,16 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
#stratified split #stratified split
list_train_split=[] list_train_split=[]
list_val_split =[] list_val_split =[]
total_count = 0
for cl in [1,2,3]: for cl in [1,2,3]:
df_class = df_flyer[df_flyer['Classes']==cl] df_class = df_flyer[df_flyer['Classes']==cl]
class_count = df_class.shape[0] class_count = df_class.shape[0]
list_train_split.append(df_class.iloc[:int(class_count*split[0]),:]) list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
list_val_split.append(df_class.iloc[int(class_count * split[0]):, :]) list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
total_count+=class_count
list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :]) total_count=total_count/3
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :]) list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
...@@ -37,7 +39,7 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0 ...@@ -37,7 +39,7 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False) df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False)
def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv') df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv') df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ'] df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
df_no_flyer = df_no_flyer[['Sequences', 'Classes']] df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
...@@ -46,22 +48,25 @@ def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1 ...@@ -46,22 +48,25 @@ def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1
#stratified split #stratified split
list_train_split=[] list_train_split=[]
list_val_split =[] list_val_split =[]
total_count = 0
for cl in [1,2,3]: for cl in [1,2,3]:
df_class = df_flyer[df_flyer['Classes']==cl] df_class = df_flyer[df_flyer['Classes']==cl]
class_count = df_class.shape[0] class_count = df_class.shape[0]
list_train_split.append(df_class.iloc[:int(class_count*split[0]),:]) list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
list_val_split.append(df_class.iloc[int(class_count * split[0]):, :]) list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
total_count += class_count
total_count = total_count / 3
list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :]) list_train_split.append(df_no_flyer.iloc[:int(total_count * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :]) list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(total_count * split[1] * frac_no_fly_val):, :])
df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
df_train['Proteins']=0 df_train['Proteins']=0
df_val['Proteins'] = 0 df_val['Proteins'] = 0
df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False) df_train.to_csv('df_preprocessed/df_train_astral_4.csv', index=False)
df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False) df_val.to_csv('df_preprocessed/df_val_astral_binary_4.csv',index=False)
def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv') df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
...@@ -152,7 +157,7 @@ def main(): ...@@ -152,7 +157,7 @@ def main():
fine_tuned_model = DetectabilityModel(num_units=num_cells, fine_tuned_model = DetectabilityModel(num_units=num_cells,
num_clases=total_num_classes) num_clases=total_num_classes)
fine_tuned_model.load_weights(load_model_path) # fine_tuned_model.load_weights(load_model_path)
...@@ -162,8 +167,8 @@ def main(): ...@@ -162,8 +167,8 @@ def main():
print('Initialising dataset') print('Initialising dataset')
## Data init ## Data init
fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv', fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_4.csv',
val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv', val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
data_format='csv', data_format='csv',
max_seq_len=max_pep_length, max_seq_len=max_pep_length,
label_column="Classes", label_column="Classes",
...@@ -199,12 +204,12 @@ def main(): ...@@ -199,12 +204,12 @@ def main():
history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data, history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
validation_data=fine_tune_data.tensor_val_data, validation_data=fine_tune_data.tensor_val_data,
epochs=150, epochs=450,
callbacks=[callback_FT, model_checkpoint_FT]) callbacks=[callback_FT, model_checkpoint_FT])
## Loading best model weights ## Loading best model weights
model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
# model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
fine_tuned_model.load_weights(model_save_path_FT) fine_tuned_model.load_weights(model_save_path_FT)
...@@ -214,9 +219,6 @@ def main(): ...@@ -214,9 +219,6 @@ def main():
# access val dataset and get the Classes column # access val dataset and get the Classes column
test_targets_FT = fine_tune_data["val"]["Classes"] test_targets_FT = fine_tune_data["val"]["Classes"]
# if needed, the decoded version of the classes can be retrieved by looking up the class names
test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]
# The dataframe needed for the report # The dataframe needed for the report
test_data_df_FT = pd.DataFrame( test_data_df_FT = pd.DataFrame(
...@@ -238,17 +240,18 @@ def main(): ...@@ -238,17 +240,18 @@ def main():
report_FT = DetectabilityReport(test_targets_FT_one_hot, report_FT = DetectabilityReport(test_targets_FT_one_hot,
predictions_FT, predictions_FT,
test_data_df_FT, test_data_df_FT,
output_path='./output/report_on_combined_15 (Fine tuned model (combined_10) categorical train, categorical val)', output_path='./output/report_on_astral_4 (from scratch model categorical train, categorical val )',
history=history_fine_tuned, history=history_fine_tuned,
rank_by_prot=True, rank_by_prot=True,
threshold=None, threshold=None,
name_of_dataset='combined_15 val dataset (categorical balanced)', name_of_dataset='astral_4 val dataset (Categorical balanced)',
name_of_model='Fine tuned model (combined_15)') name_of_model='From scratch model')
report_FT.generate_report() report_FT.generate_report()
if __name__ == '__main__': if __name__ == '__main__':
# create_astral_dataset() create_astral_dataset(frac_no_fly_val=1)
# create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1) # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
create_ISA_dataset(frac_no_fly_val=1)
main() main()
# density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv') # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment