diff --git a/binary_training.py b/binary_training.py index f9a4a64be368fdf5834526af84686d3e34c2f0f9..34854c63936f6ef2a96a0a1cf2f798a370eedd70 100644 --- a/binary_training.py +++ b/binary_training.py @@ -36,8 +36,8 @@ def create_ISA_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_tra df_val.to_csv('df_preprocessed/df_val_ISA_binary.csv',index=False) def create_astral_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1): - df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv') - df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv') + df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_no_miscleavage_astral.csv') + df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_no_miscleavage_astral.csv') df_no_flyer['Classes'] = 0 df_no_flyer = df_no_flyer[['Sequences', 'Classes']] df_flyer['Classes'] = 1 @@ -229,6 +229,7 @@ def plot_roc(df,base_path): if __name__ == '__main__': create_astral_binary_dataset() test_data_df_FT, history = main(epoch=150) + test_data_df_FT.to_csv('output/binary_astral/Dectetability_prediction_report.csv') base_path = 'output/binary_astral' plot_and_save_metrics(history,base_path) plot_confusion_matrix(test_data_df_FT,base_path) diff --git a/dataset_extraction.py b/dataset_extraction.py index ea2d685573f8a8a24369be4d4216881b0ee44a6f..b6610c7117e3ea09eecba0fecb26792d43548dd5 100644 --- a/dataset_extraction.py +++ b/dataset_extraction.py @@ -1,3 +1,6 @@ +from cProfile import label + +import matplotlib.pyplot as plt import pandas as pd """ @@ -135,6 +138,116 @@ def build_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df df_final.to_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv', index=False) df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv', index=False) +def build_regression_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'): + + df = pd.read_excel('ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx') + df_non_flyer = pd.read_excel('ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx') + #No flyer + df_non_flyer = df_non_flyer[df_non_flyer['Cystein ?']==0] + df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['Miscleavage ?'])] + df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['MaxLFQ'])] + df_non_flyer['Sequences'] = df_non_flyer['Peptide'] + df_non_flyer['Proteins'] = df_non_flyer['ProteinID'] + df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates() + df_non_flyer['Value fragment'] = 0 + df_non_flyer['Value precursor'] = 0 + df_non_flyer['Value MaxLFQ'] = 0 + + + #Flyer + quantites_table = pd.read_csv('ISA_data/250505_mix_12_souches_lib_12_especes_conta_ASTRAL_BIOASTER_quantities.csv') + + df_filtered = df[~(pd.isna(df['Proteotypic ?']))] + df_filtered = df_filtered[df_filtered['Coverage']>=coverage_treshold] + df_filtered = df_filtered[pd.isna(df_filtered['Miscleavage ? '])] + peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts') + filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"] + df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())] + quantites_table_filtered = quantites_table[quantites_table['Modified.Sequence'].isin(df_filtered['Stripped.Sequence'])] + df_filtered = pd.merge(quantites_table_filtered, df_filtered, how='inner',left_on='Modified.Sequence', right_on='Stripped.Sequence') + df1_grouped = df_filtered.groupby("Protein.Names") + dico_final={} + # iterate over each group + for group_name, df_group in df1_grouped: + seq = df_group['Stripped.Sequence'].to_list() + value_maxlfq = df_group['20250129_ISA_MIX-1_48SPD_001'].to_list() + value_frag = df_group['Fragment.Quant.Raw'].to_list() + value_prec = df_group['Precursor.Quantity'].to_list() + + prot = df_group['Protein.Group'].to_list()[0] + max_frag = max(value_frag) + max_prec = max(value_prec) + max_max_lfq = max(value_maxlfq) + for i in range(len(seq)): + label_frag = value_frag[i]/max_frag + label_prec = value_prec[i] / max_prec + label_maxlfq = value_maxlfq[i] / max_max_lfq + dico_final[seq[i]] = (prot,label_frag,label_prec,label_maxlfq) + + df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Value fragment','Value precursor', 'Value MaxLFQ']) + df_final['Sequences']=df_final.index + df_final = df_final.reset_index() + df_final=df_final[['Sequences','Proteins','Value fragment','Value precursor', 'Value MaxLFQ']] + # df_final.to_csv('ISA_data/datasets/df_flyer_astral_reg_{}.csv'.format(min_peptide), index=False) + # df_non_flyer.to_csv('ISA_data/datasets/df_non_flyer_astral_reg.csv', index=False) + return df_final + +def build_dataset_regression_zeno(coverage_treshold = 20, min_peptide = 4): + df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx') + df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv') + #No flyer + df_non_flyer = df_non_flyer[df_non_flyer['Cystein ? ']=='Any'] + df_non_flyer = df_non_flyer[df_non_flyer['Miscleavage ?'] == 'Any'] + df_non_flyer = df_non_flyer[df_non_flyer['MaxLFQ'] == 0.0] + df_non_flyer['Sequences'] = df_non_flyer['Peptide'] + df_non_flyer['Proteins'] = df_non_flyer['ProteinID'] + df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates() + df_non_flyer['Value fragment']=0 + df_non_flyer['Value precursor'] =0 + df_non_flyer['Value MaxLFQ'] =0 + + + #Flyer + df_filtered = df[df['Proteotypic ?']=='Proteotypic'] + df_filtered = df_filtered[df_filtered['Coverage ']>=coverage_treshold] + df_filtered = df_filtered[df_filtered['Miscleavage ?']=='Any'] + df_filtered = df_filtered[~df_filtered['Precursor.Quantity'].isna()] + peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts') + filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"] + df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())] + + df1_grouped = df_filtered.groupby("Protein.Names") + dico_final={} + # iterate over each group + for group_name, df_group in df1_grouped: + seq = df_group['Stripped.Sequence'].to_list() + value_frag = df_group['Fragment.Quant.Raw'].to_list() + value_prec = df_group['Precursor.Quantity'].to_list() + value_maxlfq = df_group['MaxLFQ'].to_list() + + prot = df_group['Protein.Group'].to_list()[0] + max_frag = max(value_frag) + max_prec = max(value_prec) + max_max_lfq = max(value_maxlfq) + for i in range(len(seq)): + label_frag = value_frag[i]/max_frag + label_prec = value_prec[i] / max_prec + label_maxlfq = value_maxlfq[i] / max_max_lfq + dico_final[seq[i]] = (prot,label_frag,label_prec,label_maxlfq) + + df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Value fragment','Value precursor', 'Value MaxLFQ']) + df_final['Sequences']=df_final.index + df_final = df_final.reset_index() + df_final=df_final[['Sequences','Proteins','Value fragment','Value precursor', 'Value MaxLFQ']] + df_final.to_csv('ISA_data/datasets/df_flyer_zeno_reg.csv', index=False) + df_non_flyer.to_csv('ISA_data/datasets/df_non_flyer_zeno_reg.csv', index=False) + if __name__ == '__main__': - build_dataset_astral(coverage_treshold=20, min_peptide=15) + df_size=[] + for min_pep in range(4,20): + df = build_regression_dataset_astral(coverage_treshold=20, min_peptide=min_pep) + df_size.append(df.shape[0]) + plt.clf() + plt.bar([i for i in range(4,20)],df_size) + plt.savefig('number_of_peptides_thr.png') diff --git a/main_fine_tune.py b/main_fine_tune.py index d4ad775974e1d456f2ac2b720c71a2d9aab8b8f2..d58ba437cb859ce4616646896e957a80b300165b 100644 --- a/main_fine_tune.py +++ b/main_fine_tune.py @@ -167,7 +167,7 @@ def main(): print('Initialising dataset') ## Data init - fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_4.csv', + fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_4.csv', val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv', data_format='csv', max_seq_len=max_pep_length, @@ -204,7 +204,7 @@ def main(): history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data, validation_data=fine_tune_data.tensor_val_data, - epochs=450, + epochs=150, callbacks=[callback_FT, model_checkpoint_FT]) ## Loading best model weights @@ -240,18 +240,18 @@ def main(): report_FT = DetectabilityReport(test_targets_FT_one_hot, predictions_FT, test_data_df_FT, - output_path='./output/report_on_astral_4 (from scratch model categorical train, categorical val )', + output_path='./output/report_on_astral_4 (Fine tune model (combined 4) categorical train, categorical val )', history=history_fine_tuned, rank_by_prot=True, threshold=None, name_of_dataset='astral_4 val dataset (Categorical balanced)', - name_of_model='From scratch model') + name_of_model='Fine tune model (combined 4)') report_FT.generate_report() if __name__ == '__main__': - create_astral_dataset(frac_no_fly_val=1) + # create_astral_dataset(frac_no_fly_val=1) # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1) - create_ISA_dataset(frac_no_fly_val=1) + # create_ISA_dataset(frac_no_fly_val=1) main() # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv') diff --git a/regression_training.py b/regression_training.py new file mode 100644 index 0000000000000000000000000000000000000000..061b280761073854951e5b6d9e252fbebdce98af --- /dev/null +++ b/regression_training.py @@ -0,0 +1,277 @@ +import numpy as np +import pandas as pd +import tensorflow as tf +import matplotlib.pyplot as plt +from dlomix.models import DetectabilityModel +from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict +from dlomix.data import DetectabilityDataset +from os.path import join, exists +from os import makedirs +from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve + +def create_zeno_reg_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=3,frac_no_fly_val=1): + df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_zeno_reg.csv') + df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_zeno_reg.csv') + df_flyer['Value']=df_flyer['Value precursor'].apply(np.sqrt) + df_no_flyer['Value'] = df_no_flyer['Value precursor'].apply(np.sqrt) + df_no_flyer = df_no_flyer[['Sequences', 'Value']] + df_flyer = df_flyer[['Sequences','Value']] + #stratified split + list_train_split=[] + list_val_split =[] + + flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer + list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:]) + list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :]) + + list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :]) + list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :]) + + df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle + df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle + + df_train['Proteins'] = 0 + df_val['Proteins'] = 0 + df_train.to_csv('df_preprocessed/df_train_zeno_reg.csv', index=False) + df_val.to_csv('df_preprocessed/df_val_zeno_reg.csv',index=False) + +def create_astral_reg_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=2,frac_no_fly_val=1): + df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_astral_reg_4.csv') + df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_astral_reg.csv') + df_flyer['Value'] = df_flyer['Value precursor'].apply(np.sqrt) + df_no_flyer['Value'] = df_no_flyer['Value precursor'].apply(np.sqrt) + df_no_flyer = df_no_flyer[['Sequences', 'Value']] + df_flyer = df_flyer[['Sequences', 'Value']] + #stratified split + list_train_split=[] + list_val_split =[] + + flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer + list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:]) + list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :]) + + list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :]) + list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :]) + + df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle + df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle + + df_train['Proteins'] = 0 + df_val['Proteins'] = 0 + df_train.to_csv('df_preprocessed/df_train_astral_reg.csv', index=False) + df_val.to_csv('df_preprocessed/df_val_astral_reg.csv',index=False) + + +def main(epoch): + total_num_classes = 2 + num_cells = 64 + + load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' + fine_tuned_model = DetectabilityModel(num_units=num_cells, + num_clases=1) + + fine_tuned_model.decoder.decoder_dense = tf.keras.layers.Dense(1, activation=None) + fine_tuned_model.build((None, 40)) + + base_arch = DetectabilityModel(num_units=num_cells, + num_clases=4) + base_arch.load_weights(load_model_path) + + #partially loading pretrained weights (multiclass training) + base_arch.build((None, 40)) + weights_list = base_arch.get_weights() + weights_list[-1] = np.array([0.],dtype=np.float32) + weights_list[-2] = np.zeros((128,1),dtype=np.float32) + fine_tuned_model.set_weights(weights_list) + + + + max_pep_length = 40 + ## Has no impact for prediction + batch_size = 16 + + print('Initialising dataset') + ## Data init + fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_zeno_reg.csv', + val_data_source='df_preprocessed/df_val_zeno_reg.csv', + data_format='csv', + max_seq_len=max_pep_length, + label_column="Value", + sequence_column="Sequences", + dataset_columns_to_keep=["Proteins"], + batch_size=batch_size, + with_termini=False, + alphabet=aa_to_int_dict) + + + + + # compile the model with the optimizer and the metrics we want to use. + callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss', + mode='min', + verbose=1, + patience=50) + + model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' + + model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT, + monitor='val_loss', + mode='min', + verbose=1, + save_best_only=True, + save_weights_only=True) + opti = tf.keras.optimizers.legacy.Adagrad() + fine_tuned_model.compile(optimizer=opti, + loss='MeanAbsoluteError', + metrics='RootMeanSquaredError') + + + + history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data, + validation_data=fine_tune_data.tensor_val_data, + epochs=epoch, + callbacks=[callback_FT, model_checkpoint_FT]) + ## Loading best model weights + + model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data + # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model + + fine_tuned_model.load_weights(model_save_path_FT) + + predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data) + + # access val dataset and get the Classes column + test_targets_FT = fine_tune_data["val"]["Value"] + + # The dataframe needed for the report + + test_data_df_FT = pd.DataFrame( + { + "Sequences": fine_tune_data["val"]["_parsed_sequence"], # get the raw parsed sequences + "Classes": test_targets_FT, # get the test targets from above + "Proteins": fine_tune_data["val"]["Proteins"], # get the Proteins column from the dataset object + "Prediction": predictions_FT[:,0], + } + ) + + + + test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x)) + + return test_data_df_FT, history_fine_tuned + +def plot_and_save_metrics(history, base_path): + history_dict = history.history + metrics = history_dict.keys() + metrics = filter(lambda x: not x.startswith(tuple(["val_", "_"])), metrics) + + if not exists(base_path): + makedirs(base_path) + + for metric_name in metrics: + plt.plot(history_dict[metric_name]) + plt.plot(history_dict["val_" + metric_name]) + plt.title(metric_name, fontsize=10) # Modified Original plt.title(metric_name) + plt.ylabel(metric_name) + plt.xlabel("epoch") + plt.legend(["train", "val"], loc="best") + save_path = join(base_path, metric_name) + plt.savefig( + save_path, bbox_inches="tight", dpi=90 + ) # Modification Original plt.savefig(save_path) + plt.close() + +def plot_confusion_matrix(df, base_path): + conf_matrix = confusion_matrix( + df["Classes"], + df["Predicted classes"], + ) + + if not exists(base_path): + makedirs(base_path) + + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"] + ) + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11) + save_path = join(base_path, "confusion_matrix_binary" + ) + plt.savefig(save_path, bbox_inches="tight", dpi=80) + plt.close() + +def plot_roc(df,base_path): + fpr, tpr, thresholds = roc_curve( + np.array(df["Classes"]), + np.array(df["Prediction"]), + ) + AUC_score = auc(fpr, tpr) + + # create ROC curve + + plt.plot(fpr, tpr, label="ROC curve of (area = {})".format(AUC_score)) + plt.title( + "Receiver operating characteristic curve (Binary classification)", + y=1.04, + fontsize=10, + ) + + plt.ylabel("True Positive Rate") + plt.xlabel("False Positive Rate") + save_path = join( + base_path, "ROC_curve_binary_classification" + ) + + plt.savefig(save_path, bbox_inches="tight", dpi=90) + plt.close() + +def optimize_threshold(df,frac): + tr_list=[] + acc_list=[] + for tr in range(200): + df['Binary Classes'] = df['Classes'] != 0 + df['Binary Prediction'] = df['Prediction'] >= tr/200 + conf_matrix = confusion_matrix( + df["Binary Classes"], + df["Binary Prediction"], + ) + tr_list.append(tr/200) + acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum()) + plt.clf() + plt.plot(tr_list,acc_list) + plt.savefig('acc_tuning_zeno_sqrt_mae{}.png'.format(frac)) + plt.clf() + print(max(acc_list)) + #best conf matrix + tr = tr_list[np.argmax(acc_list)] + df['Binary Classes'] = df['Classes'] != 0 + df['Binary Prediction'] = df['Prediction'] >= tr + conf_matrix = confusion_matrix( + df["Binary Classes"], + df["Binary Prediction"], + ) + + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"] + ) + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11) + save_path = join('output', "confusion_matrix_zeno_reg_sqrt_mae{}".format(frac) + ) + plt.savefig(save_path, bbox_inches="tight", dpi=80) + plt.close() + + return tr,acc_list + + +if __name__ == '__main__': + for f in [1,1.5]: + create_zeno_reg_dataset(frac_no_fly_train=f) + test_data_df_FT, history = main(epoch=150) + optimize_threshold(test_data_df_FT,f) + # base_path = 'output/binary_astral' + # plot_and_save_metrics(history,base_path) + # plot_confusion_matrix(test_data_df_FT,base_path) + # plot_roc(test_data_df_FT,base_path) \ No newline at end of file diff --git a/threshold_opti_categorical.py b/threshold_opti_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..588d2e4e08400f9b499e08ca28f12326c79f10c3 --- /dev/null +++ b/threshold_opti_categorical.py @@ -0,0 +1,80 @@ +import pandas as pd +import matplotlib.pyplot as plt +from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve +import numpy as np + +def optimize_tr_categorical(df_path): + df = pd.read_csv(df_path) + df['Bool Classes']=df['Binary Classes']!='Non-Flyer' + tr_list = [] + acc_list =[] + for tr in range(200): + df['Opti Classes']=df['Non-Flyer']<=tr/200 + conf_matrix = confusion_matrix( + df["Bool Classes"], + df["Opti Classes"] + ) + tr_list.append(tr/200) + acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum()) + plt.clf() + plt.plot(tr_list,acc_list) + plt.legend(['max acc : {}'.format(max(acc_list))]) + plt.savefig('acc_opti_zeno.png') + plt.clf() + tr = tr_list[np.argmax(acc_list)] + + df['Opti Classes']=df['Non-Flyer']<=tr + conf_matrix = confusion_matrix( + df["Bool Classes"], + df["Opti Classes"] + ) + + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=conf_matrix, display_labels=["Non-Flyer","Flyer"] + ) + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11) + plt.savefig('best_conf_matrix_zeno_4.png', bbox_inches="tight", dpi=80) + plt.close() + + +def optimize_tr_binary(df_path): + df = pd.read_csv(df_path) + df['Bool Classes']=df['Classes']!=0 + tr_list = [] + acc_list =[] + for tr in range(200): + df['Opti Classes']=df['Prob non flyer']<=tr/200 + conf_matrix = confusion_matrix( + df["Bool Classes"], + df["Opti Classes"] + ) + tr_list.append(tr/200) + acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum()) + plt.clf() + plt.plot(tr_list,acc_list) + plt.legend(['max acc : {}'.format(max(acc_list))]) + plt.savefig('acc_opti_binary_astral.png') + plt.clf() + tr = tr_list[np.argmax(acc_list)] + + df['Opti Classes']=df['Prob non flyer']<=tr + conf_matrix = confusion_matrix( + df["Bool Classes"], + df["Opti Classes"] + ) + + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=conf_matrix, display_labels=["Non-Flyer","Flyer"] + ) + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11) + plt.savefig('best_conf_matrix_binary_astral.png', bbox_inches="tight", dpi=80) + plt.close() + +if __name__ == '__main__': + # optimize_tr_categorical( + # 'output/report_on_ISA (Fine tune model categorical train, binary val )/Dectetability_prediction_report.csv') + optimize_tr_binary('output/binary_astral/Dectetability_prediction_report.csv') \ No newline at end of file