Skip to content
Snippets Groups Projects
Commit 89df63ba authored by Schneider Leo's avatar Schneider Leo
Browse files

astral dataset

parent 7537263a
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict ...@@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict
df_list =["Wilhelmlab/detectability-proteometools", "Wilhelmlab/detectability-wang","Wilhelmlab/detectability-sinitcyn"] df_list =["Wilhelmlab/detectability-proteometools", "Wilhelmlab/detectability-wang","Wilhelmlab/detectability-sinitcyn"]
df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv') df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv') df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
for label_type in ['Classes fragment','Classes precursor', 'Classes MaxLFQ'] : for label_type in ['Classes fragment','Classes precursor', 'Classes MaxLFQ'] :
......
...@@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"} ...@@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"}
""" """
def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'): def build_dataset(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx') df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx')
df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv') df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv')
#No flyer #No flyer
...@@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m ...@@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
dico_final={} dico_final={}
# iterate over each group # iterate over each group
for group_name, df_group in df1_grouped: for group_name, df_group in df1_grouped:
seq = df_group.sort_values(by=[intensity_col])['Stripped.Sequence'].to_list() seq = df_group.sort_values(by=['Fragment.Quant.Raw'])['Stripped.Sequence'].to_list()
value_frag = df_group.sort_values(by=[intensity_col])[intensity_col].to_list() value_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Fragment.Quant.Raw'].to_list()
value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list() value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list()
value_prec_frag = df_group.sort_values(by=[intensity_col])['Precursor.Quantity'].to_list() value_prec_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['Precursor.Quantity'].to_list()
value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list() value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list()
value_maxlfq_frag = df_group.sort_values(by=[intensity_col])['MaxLFQ'].to_list() value_maxlfq_frag = df_group.sort_values(by=['Fragment.Quant.Raw'])['MaxLFQ'].to_list()
threshold_weak_flyer_frag = value_frag[int(len(seq) / 3)] threshold_weak_flyer_frag = value_frag[int(len(seq) / 3)]
threshold_medium_flyer_frag = value_frag[int(2*len(seq) / 3)] threshold_medium_flyer_frag = value_frag[int(2*len(seq) / 3)]
threshold_weak_flyer_prec = value_prec[int(len(seq) / 3)] threshold_weak_flyer_prec = value_prec[int(len(seq) / 3)]
...@@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m ...@@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
df_final=df_final[['Sequences','Proteins','Classes fragment','Classes precursor', 'Classes MaxLFQ']] df_final=df_final[['Sequences','Proteins','Classes fragment','Classes precursor', 'Classes MaxLFQ']]
df_final.to_csv(f_name, index=False) df_final.to_csv(f_name, index=False)
df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False) df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False)
def build_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
df = pd.read_excel('ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx')
df_non_flyer = pd.read_excel('ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx')
#No flyer
df_non_flyer = df_non_flyer[df_non_flyer['Cystein ?']==0]
df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['Miscleavage ?'])]
df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['MaxLFQ'])]
df_non_flyer['Sequences'] = df_non_flyer['Peptide']
df_non_flyer['Proteins'] = df_non_flyer['ProteinID']
df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates()
df_non_flyer['Classes MaxLFQ'] =0
#Flyer
df_filtered = df[~(pd.isna(df['Proteotypic ?']))]
df_filtered = df_filtered[df_filtered['Coverage']>=coverage_treshold]
df_filtered = df_filtered[pd.isna(df_filtered['Miscleavage ? '])]
peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"]
df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]
df1_grouped = df_filtered.groupby("Protein.Names")
dico_final={}
# iterate over each group
for group_name, df_group in df1_grouped:
seq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['Stripped.Sequence'].to_list()
value_maxlfq = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list()
value_maxlfq_frag = df_group.sort_values(by=['20250129_ISA_MIX-1_48SPD_001'])['20250129_ISA_MIX-1_48SPD_001'].to_list()
threshold_weak_flyer_maxflq = value_maxlfq[int(len(seq) / 3)]
threshold_medium_flyer_maxlfq = value_maxlfq[int(2 * len(seq) / 3)]
prot = df_group['Protein.Group'].to_list()[0]
for i in range(len(seq)):
if value_maxlfq_frag[i] < threshold_weak_flyer_maxflq :
label_maxlfq = 1
elif value_maxlfq_frag[i] < threshold_medium_flyer_maxlfq :
label_maxlfq = 2
else :
label_maxlfq = 3
dico_final[seq[i]] = (prot,label_maxlfq)
df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Classes MaxLFQ'])
df_final['Sequences']=df_final.index
df_final = df_final.reset_index()
df_final=df_final[['Sequences','Proteins', 'Classes MaxLFQ']]
df_final.to_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv', index=False)
df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv', index=False)
if __name__ == '__main__': if __name__ == '__main__':
build_dataset( coverage_treshold=20, min_peptide=4, f_name='ISA_data/df_finetune_no_miscleavage.csv') build_dataset_astral(coverage_treshold=20, min_peptide=15)
...@@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict ...@@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict
from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report
def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=2): def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv') df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv') df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
df_no_flyer['Classes'] = df_no_flyer[classe_type] df_no_flyer['Classes'] = df_no_flyer[classe_type]
df_no_flyer = df_no_flyer[['Sequences', 'Classes']] df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
...@@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0 ...@@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :]) list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :]) list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
df_train['Proteins']=0 df_train['Proteins']=0
df_val['Proteins'] = 0 df_val['Proteins'] = 0
df_train.to_csv('temp_fine_tune_df_train.csv', index=False) df_train.to_csv('df_preprocessed/df_train_ISA.csv', index=False)
df_val.to_csv('temp_fine_tune_df_val.csv',index=False) df_val.to_csv('df_preprocessed/df_val_ISA_multiclass.csv',index=False)
def create_astral_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
df_flyer['Classes']=df_flyer['Classes MaxLFQ']
df_flyer = df_flyer[['Sequences','Classes']]
#stratified split
list_train_split=[]
list_val_split =[]
for cl in [1,2,3]:
df_class = df_flyer[df_flyer['Classes']==cl]
class_count = df_class.shape[0]
list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
df_train['Proteins']=0
df_val['Proteins'] = 0
df_train.to_csv('df_preprocessed/df_train_astral_15.csv', index=False)
df_val.to_csv('df_preprocessed/df_val_astral_multiclass_15.csv',index=False)
def create_combine_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
df_flyer_astral = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral_7.csv')
df_no_flyer_astral = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
df_no_flyer_astral['Classes'] = df_no_flyer_astral['Classes MaxLFQ']
df_no_flyer_astral = df_no_flyer_astral[['Sequences', 'Classes']]
df_flyer_astral['Classes']=df_flyer_astral['Classes MaxLFQ']
df_flyer_astral = df_flyer_astral[['Sequences','Classes']]
df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage.csv')
df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
df_no_flyer['Classes'] = df_no_flyer['Classes MaxLFQ']
df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
df_flyer['Classes'] = df_flyer['Classes MaxLFQ']
df_flyer = df_flyer[['Sequences', 'Classes']]
#stratified split
list_train_split=[]
list_val_split =[]
for cl in [1,2,3]:
df_class_astral = df_flyer_astral[df_flyer_astral['Classes']==cl]
class_count_astral = df_class_astral.shape[0]
df_class_ISA = df_flyer[df_flyer['Classes'] == cl]
class_count_ISA = df_class_ISA.shape[0]
list_train_split.append(df_class_astral.iloc[:int(class_count_astral*split[0]),:])
list_val_split.append(df_class_astral.iloc[int(class_count_astral * split[0]):, :])
list_train_split.append(df_class_ISA.iloc[:int(class_count_ISA * split[0]), :])
list_val_split.append(df_class_ISA.iloc[int(class_count_ISA * split[0]):, :])
list_train_split.append(df_no_flyer_astral.iloc[:int(class_count_astral * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer_astral.iloc[df_no_flyer_astral.shape[0]-int(class_count_astral * split[1] * frac_no_fly_val):, :])
list_train_split.append(df_no_flyer.iloc[:int(class_count_ISA * split[0] * frac_no_fly_train), :])
list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0] - int(class_count_ISA * split[1] * frac_no_fly_val):, :])
df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
df_train['Proteins']=0
df_val['Proteins'] = 0
df_train.to_csv('df_preprocessed/df_train_combined_7.csv', index=False)
df_val.to_csv('df_preprocessed/df_val_combined_multiclass_7.csv',index=False)
def density_plot(prediction_path,prediction_path_2,criteria='base'): def density_plot(prediction_path,prediction_path_2,criteria='base'):
df = pd.read_csv(prediction_path) df = pd.read_csv(prediction_path)
...@@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'): ...@@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'):
def main(): def main():
total_num_classes = len(CLASSES_LABELS) total_num_classes = len(CLASSES_LABELS)
input_dimension = len(alphabet)
num_cells = 64 num_cells = 64
fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes)
load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
fine_tuned_model = DetectabilityModel(num_units=num_cells, fine_tuned_model = DetectabilityModel(num_units=num_cells,
num_clases=total_num_classes) num_clases=total_num_classes)
...@@ -94,8 +162,8 @@ def main(): ...@@ -94,8 +162,8 @@ def main():
print('Initialising dataset') print('Initialising dataset')
## Data init ## Data init
fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv', fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_15.csv',
val_data_source='temp_fine_tune_df_val.csv', val_data_source='df_preprocessed/df_val_combined_multiclass_15.csv',
data_format='csv', data_format='csv',
max_seq_len=max_pep_length, max_seq_len=max_pep_length,
label_column="Classes", label_column="Classes",
...@@ -114,7 +182,7 @@ def main(): ...@@ -114,7 +182,7 @@ def main():
verbose=1, verbose=1,
patience=5) patience=5)
model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability' model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT, model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
monitor='val_loss', monitor='val_loss',
...@@ -131,13 +199,13 @@ def main(): ...@@ -131,13 +199,13 @@ def main():
history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data, history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
validation_data=fine_tune_data.tensor_val_data, validation_data=fine_tune_data.tensor_val_data,
epochs=1, epochs=150,
callbacks=[callback_FT, model_checkpoint_FT]) callbacks=[callback_FT, model_checkpoint_FT])
## Loading best model weights ## Loading best model weights
# model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability' model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
fine_tuned_model.load_weights(model_save_path_FT) fine_tuned_model.load_weights(model_save_path_FT)
...@@ -170,16 +238,17 @@ def main(): ...@@ -170,16 +238,17 @@ def main():
report_FT = DetectabilityReport(test_targets_FT_one_hot, report_FT = DetectabilityReport(test_targets_FT_one_hot,
predictions_FT, predictions_FT,
test_data_df_FT, test_data_df_FT,
output_path='./output/report_on_ISA (Base model categorical train, binary val )', output_path='./output/report_on_combined_15 (Fine tuned model (combined_10) categorical train, categorical val)',
history=history_fine_tuned, history=history_fine_tuned,
rank_by_prot=True, rank_by_prot=True,
threshold=None, threshold=None,
name_of_dataset='ISA val dataset (binary balanced)', name_of_dataset='combined_15 val dataset (categorical balanced)',
name_of_model='Base model (ISA)') name_of_model='Fine tuned model (combined_15)')
report_FT.generate_report() report_FT.generate_report()
if __name__ == '__main__': if __name__ == '__main__':
create_ISA_dataset() # create_astral_dataset()
# create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
main() main()
# density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv') # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment