diff --git a/binary_training.py b/binary_training.py
index f9a4a64be368fdf5834526af84686d3e34c2f0f9..34854c63936f6ef2a96a0a1cf2f798a370eedd70 100644
--- a/binary_training.py
+++ b/binary_training.py
@@ -36,8 +36,8 @@ def create_ISA_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_tra
     df_val.to_csv('df_preprocessed/df_val_ISA_binary.csv',index=False)
 
 def create_astral_binary_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=1):
-    df_flyer = pd.read_csv('ISA_data/df_flyer_no_miscleavage_astral.csv')
-    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv')
+    df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_no_miscleavage_astral.csv')
+    df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_no_miscleavage_astral.csv')
     df_no_flyer['Classes'] = 0
     df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
     df_flyer['Classes'] = 1
@@ -229,6 +229,7 @@ def plot_roc(df,base_path):
 if __name__ == '__main__':
     create_astral_binary_dataset()
     test_data_df_FT, history = main(epoch=150)
+    test_data_df_FT.to_csv('output/binary_astral/Dectetability_prediction_report.csv')
     base_path = 'output/binary_astral'
     plot_and_save_metrics(history,base_path)
     plot_confusion_matrix(test_data_df_FT,base_path)
diff --git a/dataset_extraction.py b/dataset_extraction.py
index ea2d685573f8a8a24369be4d4216881b0ee44a6f..b6610c7117e3ea09eecba0fecb26792d43548dd5 100644
--- a/dataset_extraction.py
+++ b/dataset_extraction.py
@@ -1,3 +1,6 @@
+from cProfile import label
+
+import matplotlib.pyplot as plt
 import pandas as pd
 
 """
@@ -135,6 +138,116 @@ def build_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df
     df_final.to_csv('ISA_data/df_flyer_no_miscleavage_astral_15.csv', index=False)
     df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage_astral.csv', index=False)
 
+def build_regression_dataset_astral(coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
+
+    df = pd.read_excel('ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx')
+    df_non_flyer = pd.read_excel('ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx')
+    #No flyer
+    df_non_flyer = df_non_flyer[df_non_flyer['Cystein ?']==0]
+    df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['Miscleavage ?'])]
+    df_non_flyer = df_non_flyer[pd.isna(df_non_flyer['MaxLFQ'])]
+    df_non_flyer['Sequences'] = df_non_flyer['Peptide']
+    df_non_flyer['Proteins'] = df_non_flyer['ProteinID']
+    df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates()
+    df_non_flyer['Value fragment'] = 0
+    df_non_flyer['Value precursor'] = 0
+    df_non_flyer['Value MaxLFQ'] = 0
+
+
+    #Flyer
+    quantites_table = pd.read_csv('ISA_data/250505_mix_12_souches_lib_12_especes_conta_ASTRAL_BIOASTER_quantities.csv')
+
+    df_filtered = df[~(pd.isna(df['Proteotypic ?']))]
+    df_filtered = df_filtered[df_filtered['Coverage']>=coverage_treshold]
+    df_filtered = df_filtered[pd.isna(df_filtered['Miscleavage ? '])]
+    peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
+    filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"]
+    df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]
+    quantites_table_filtered = quantites_table[quantites_table['Modified.Sequence'].isin(df_filtered['Stripped.Sequence'])]
+    df_filtered = pd.merge(quantites_table_filtered, df_filtered, how='inner',left_on='Modified.Sequence', right_on='Stripped.Sequence')
+    df1_grouped = df_filtered.groupby("Protein.Names")
+    dico_final={}
+    # iterate over each group
+    for group_name, df_group in df1_grouped:
+        seq = df_group['Stripped.Sequence'].to_list()
+        value_maxlfq = df_group['20250129_ISA_MIX-1_48SPD_001'].to_list()
+        value_frag = df_group['Fragment.Quant.Raw'].to_list()
+        value_prec = df_group['Precursor.Quantity'].to_list()
+
+        prot = df_group['Protein.Group'].to_list()[0]
+        max_frag = max(value_frag)
+        max_prec = max(value_prec)
+        max_max_lfq = max(value_maxlfq)
+        for i in range(len(seq)):
+            label_frag = value_frag[i]/max_frag
+            label_prec = value_prec[i] / max_prec
+            label_maxlfq = value_maxlfq[i] / max_max_lfq
+            dico_final[seq[i]] = (prot,label_frag,label_prec,label_maxlfq)
+
+    df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Value fragment','Value precursor', 'Value MaxLFQ'])
+    df_final['Sequences']=df_final.index
+    df_final = df_final.reset_index()
+    df_final=df_final[['Sequences','Proteins','Value fragment','Value precursor', 'Value MaxLFQ']]
+    # df_final.to_csv('ISA_data/datasets/df_flyer_astral_reg_{}.csv'.format(min_peptide), index=False)
+    # df_non_flyer.to_csv('ISA_data/datasets/df_non_flyer_astral_reg.csv', index=False)
+    return df_final
+
+def build_dataset_regression_zeno(coverage_treshold = 20, min_peptide = 4):
+    df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx')
+    df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv')
+    #No flyer
+    df_non_flyer = df_non_flyer[df_non_flyer['Cystein ? ']=='Any']
+    df_non_flyer = df_non_flyer[df_non_flyer['Miscleavage ?'] == 'Any']
+    df_non_flyer = df_non_flyer[df_non_flyer['MaxLFQ'] == 0.0]
+    df_non_flyer['Sequences'] = df_non_flyer['Peptide']
+    df_non_flyer['Proteins'] = df_non_flyer['ProteinID']
+    df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates()
+    df_non_flyer['Value fragment']=0
+    df_non_flyer['Value precursor'] =0
+    df_non_flyer['Value MaxLFQ'] =0
+
+
+    #Flyer
+    df_filtered = df[df['Proteotypic ?']=='Proteotypic']
+    df_filtered = df_filtered[df_filtered['Coverage ']>=coverage_treshold]
+    df_filtered = df_filtered[df_filtered['Miscleavage ?']=='Any']
+    df_filtered = df_filtered[~df_filtered['Precursor.Quantity'].isna()]
+    peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
+    filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"]
+    df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]
+
+    df1_grouped = df_filtered.groupby("Protein.Names")
+    dico_final={}
+    # iterate over each group
+    for group_name, df_group in df1_grouped:
+        seq = df_group['Stripped.Sequence'].to_list()
+        value_frag = df_group['Fragment.Quant.Raw'].to_list()
+        value_prec = df_group['Precursor.Quantity'].to_list()
+        value_maxlfq = df_group['MaxLFQ'].to_list()
+
+        prot = df_group['Protein.Group'].to_list()[0]
+        max_frag = max(value_frag)
+        max_prec = max(value_prec)
+        max_max_lfq = max(value_maxlfq)
+        for i in range(len(seq)):
+            label_frag = value_frag[i]/max_frag
+            label_prec = value_prec[i] / max_prec
+            label_maxlfq = value_maxlfq[i] / max_max_lfq
+            dico_final[seq[i]] = (prot,label_frag,label_prec,label_maxlfq)
+
+    df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Value fragment','Value precursor', 'Value MaxLFQ'])
+    df_final['Sequences']=df_final.index
+    df_final = df_final.reset_index()
+    df_final=df_final[['Sequences','Proteins','Value fragment','Value precursor', 'Value MaxLFQ']]
+    df_final.to_csv('ISA_data/datasets/df_flyer_zeno_reg.csv', index=False)
+    df_non_flyer.to_csv('ISA_data/datasets/df_non_flyer_zeno_reg.csv', index=False)
+
 
 if __name__ == '__main__':
-    build_dataset_astral(coverage_treshold=20, min_peptide=15)
+    df_size=[]
+    for min_pep in range(4,20):
+        df = build_regression_dataset_astral(coverage_treshold=20, min_peptide=min_pep)
+        df_size.append(df.shape[0])
+    plt.clf()
+    plt.bar([i for i in range(4,20)],df_size)
+    plt.savefig('number_of_peptides_thr.png')
diff --git a/main_fine_tune.py b/main_fine_tune.py
index d4ad775974e1d456f2ac2b720c71a2d9aab8b8f2..d58ba437cb859ce4616646896e957a80b300165b 100644
--- a/main_fine_tune.py
+++ b/main_fine_tune.py
@@ -167,7 +167,7 @@ def main():
 
     print('Initialising dataset')
     ## Data init
-    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_astral_4.csv',
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_combined_4.csv',
                                               val_data_source='df_preprocessed/df_val_astral_multiclass_4.csv',
                                               data_format='csv',
                                               max_seq_len=max_pep_length,
@@ -204,7 +204,7 @@ def main():
 
     history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
                                               validation_data=fine_tune_data.tensor_val_data,
-                                              epochs=450,
+                                              epochs=150,
                                               callbacks=[callback_FT, model_checkpoint_FT])
 
     ## Loading best model weights
@@ -240,18 +240,18 @@ def main():
     report_FT = DetectabilityReport(test_targets_FT_one_hot,
                                     predictions_FT,
                                     test_data_df_FT,
-                                    output_path='./output/report_on_astral_4 (from scratch model categorical train, categorical val )',
+                                    output_path='./output/report_on_astral_4 (Fine tune model (combined 4) categorical train, categorical val )',
                                     history=history_fine_tuned,
                                     rank_by_prot=True,
                                     threshold=None,
                                     name_of_dataset='astral_4 val dataset (Categorical balanced)',
-                                    name_of_model='From scratch model')
+                                    name_of_model='Fine tune model (combined 4)')
 
     report_FT.generate_report()
 
 if __name__ == '__main__':
-    create_astral_dataset(frac_no_fly_val=1)
+    # create_astral_dataset(frac_no_fly_val=1)
     # create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
-    create_ISA_dataset(frac_no_fly_val=1)
+    # create_ISA_dataset(frac_no_fly_val=1)
     main()
     # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
diff --git a/regression_training.py b/regression_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..061b280761073854951e5b6d9e252fbebdce98af
--- /dev/null
+++ b/regression_training.py
@@ -0,0 +1,277 @@
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from dlomix.models import DetectabilityModel
+from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
+from dlomix.data import DetectabilityDataset
+from os.path import join, exists
+from os import makedirs
+from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve
+
+def create_zeno_reg_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=3,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_zeno_reg.csv')
+    df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_zeno_reg.csv')
+    df_flyer['Value']=df_flyer['Value precursor'].apply(np.sqrt)
+    df_no_flyer['Value'] = df_no_flyer['Value precursor'].apply(np.sqrt)
+    df_no_flyer = df_no_flyer[['Sequences', 'Value']]
+    df_flyer = df_flyer[['Sequences','Value']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_zeno_reg.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_zeno_reg.csv',index=False)
+
+def create_astral_reg_dataset(manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=2,frac_no_fly_val=1):
+    df_flyer = pd.read_csv('ISA_data/datasets/df_flyer_astral_reg_4.csv')
+    df_no_flyer = pd.read_csv('ISA_data/datasets/df_non_flyer_astral_reg.csv')
+    df_flyer['Value'] = df_flyer['Value precursor'].apply(np.sqrt)
+    df_no_flyer['Value'] = df_no_flyer['Value precursor'].apply(np.sqrt)
+    df_no_flyer = df_no_flyer[['Sequences', 'Value']]
+    df_flyer = df_flyer[['Sequences', 'Value']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+
+    flyer_count = df_flyer.shape[0] #rebalanced flyer and no flyer
+    list_train_split.append(df_flyer.iloc[:int(flyer_count*split[0]),:])
+    list_val_split.append(df_flyer.iloc[int(flyer_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(flyer_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(flyer_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed) #shuffle
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed) #shuffle
+
+    df_train['Proteins'] = 0
+    df_val['Proteins'] = 0
+    df_train.to_csv('df_preprocessed/df_train_astral_reg.csv', index=False)
+    df_val.to_csv('df_preprocessed/df_val_astral_reg.csv',index=False)
+
+
+def main(epoch):
+    total_num_classes = 2
+    num_cells = 64
+
+    load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
+    fine_tuned_model = DetectabilityModel(num_units=num_cells,
+                                          num_clases=1)
+
+    fine_tuned_model.decoder.decoder_dense = tf.keras.layers.Dense(1, activation=None)
+    fine_tuned_model.build((None, 40))
+
+    base_arch  = DetectabilityModel(num_units=num_cells,
+                                          num_clases=4)
+    base_arch.load_weights(load_model_path)
+
+    #partially loading pretrained weights (multiclass training)
+    base_arch.build((None, 40))
+    weights_list = base_arch.get_weights()
+    weights_list[-1] = np.array([0.],dtype=np.float32)
+    weights_list[-2] = np.zeros((128,1),dtype=np.float32)
+    fine_tuned_model.set_weights(weights_list)
+
+
+
+    max_pep_length = 40
+    ## Has no impact for prediction
+    batch_size = 16
+
+    print('Initialising dataset')
+    ## Data init
+    fine_tune_data = DetectabilityDataset(data_source='df_preprocessed/df_train_zeno_reg.csv',
+                                              val_data_source='df_preprocessed/df_val_zeno_reg.csv',
+                                              data_format='csv',
+                                              max_seq_len=max_pep_length,
+                                              label_column="Value",
+                                              sequence_column="Sequences",
+                                              dataset_columns_to_keep=["Proteins"],
+                                              batch_size=batch_size,
+                                              with_termini=False,
+                                              alphabet=aa_to_int_dict)
+
+
+
+
+    # compile the model  with the optimizer and the metrics we want to use.
+    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
+                                                   mode='min',
+                                                   verbose=1,
+                                                   patience=50)
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined'
+
+    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
+                                                             monitor='val_loss',
+                                                             mode='min',
+                                                             verbose=1,
+                                                             save_best_only=True,
+                                                             save_weights_only=True)
+    opti = tf.keras.optimizers.legacy.Adagrad()
+    fine_tuned_model.compile(optimizer=opti,
+                             loss='MeanAbsoluteError',
+                             metrics='RootMeanSquaredError')
+
+
+
+    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
+                                              validation_data=fine_tune_data.tensor_val_data,
+                                              epochs=epoch,
+                                              callbacks=[callback_FT, model_checkpoint_FT])
+    ## Loading best model weights
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability_combined' #model fined tuned on ISA data
+    # model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
+
+    fine_tuned_model.load_weights(model_save_path_FT)
+
+    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)
+
+    # access val dataset and get the Classes column
+    test_targets_FT = fine_tune_data["val"]["Value"]
+
+    # The dataframe needed for the report
+
+    test_data_df_FT = pd.DataFrame(
+        {
+            "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # get the raw parsed sequences
+            "Classes": test_targets_FT,  # get the test targets from above
+            "Proteins": fine_tune_data["val"]["Proteins"],  # get the Proteins column from the dataset object
+            "Prediction": predictions_FT[:,0],
+        }
+    )
+
+
+
+    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))
+
+    return test_data_df_FT, history_fine_tuned
+
+def plot_and_save_metrics(history, base_path):
+    history_dict = history.history
+    metrics = history_dict.keys()
+    metrics = filter(lambda x: not x.startswith(tuple(["val_", "_"])), metrics)
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    for metric_name in metrics:
+        plt.plot(history_dict[metric_name])
+        plt.plot(history_dict["val_" + metric_name])
+        plt.title(metric_name, fontsize=10)  # Modified Original plt.title(metric_name)
+        plt.ylabel(metric_name)
+        plt.xlabel("epoch")
+        plt.legend(["train", "val"], loc="best")
+        save_path = join(base_path, metric_name)
+        plt.savefig(
+            save_path, bbox_inches="tight", dpi=90
+        )  # Modification Original plt.savefig(save_path)
+        plt.close()
+
+def plot_confusion_matrix(df, base_path):
+    conf_matrix = confusion_matrix(
+        df["Classes"],
+        df["Predicted classes"],
+    )
+
+    if not exists(base_path):
+        makedirs(base_path)
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    save_path = join(base_path, "confusion_matrix_binary"
+    )
+    plt.savefig(save_path, bbox_inches="tight", dpi=80)
+    plt.close()
+
+def plot_roc(df,base_path):
+    fpr, tpr, thresholds = roc_curve(
+        np.array(df["Classes"]),
+        np.array(df["Prediction"]),
+    )
+    AUC_score = auc(fpr, tpr)
+
+    # create ROC curve
+
+    plt.plot(fpr, tpr, label="ROC curve of (area = {})".format(AUC_score))
+    plt.title(
+        "Receiver operating characteristic curve (Binary classification)",
+        y=1.04,
+        fontsize=10,
+    )
+
+    plt.ylabel("True Positive Rate")
+    plt.xlabel("False Positive Rate")
+    save_path = join(
+        base_path, "ROC_curve_binary_classification"
+    )
+
+    plt.savefig(save_path, bbox_inches="tight", dpi=90)
+    plt.close()
+
+def optimize_threshold(df,frac):
+    tr_list=[]
+    acc_list=[]
+    for tr in range(200):
+        df['Binary Classes'] = df['Classes'] != 0
+        df['Binary Prediction'] = df['Prediction'] >= tr/200
+        conf_matrix = confusion_matrix(
+            df["Binary Classes"],
+            df["Binary Prediction"],
+        )
+        tr_list.append(tr/200)
+        acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum())
+    plt.clf()
+    plt.plot(tr_list,acc_list)
+    plt.savefig('acc_tuning_zeno_sqrt_mae{}.png'.format(frac))
+    plt.clf()
+    print(max(acc_list))
+    #best conf matrix
+    tr = tr_list[np.argmax(acc_list)]
+    df['Binary Classes'] = df['Classes'] != 0
+    df['Binary Prediction'] = df['Prediction'] >= tr
+    conf_matrix = confusion_matrix(
+        df["Binary Classes"],
+        df["Binary Prediction"],
+    )
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    save_path = join('output', "confusion_matrix_zeno_reg_sqrt_mae{}".format(frac)
+                     )
+    plt.savefig(save_path, bbox_inches="tight", dpi=80)
+    plt.close()
+
+    return tr,acc_list
+
+
+if __name__ == '__main__':
+    for f in [1,1.5]:
+        create_zeno_reg_dataset(frac_no_fly_train=f)
+        test_data_df_FT, history = main(epoch=150)
+        optimize_threshold(test_data_df_FT,f)
+    # base_path = 'output/binary_astral'
+    # plot_and_save_metrics(history,base_path)
+    # plot_confusion_matrix(test_data_df_FT,base_path)
+    # plot_roc(test_data_df_FT,base_path)
\ No newline at end of file
diff --git a/threshold_opti_categorical.py b/threshold_opti_categorical.py
new file mode 100644
index 0000000000000000000000000000000000000000..588d2e4e08400f9b499e08ca28f12326c79f10c3
--- /dev/null
+++ b/threshold_opti_categorical.py
@@ -0,0 +1,80 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve
+import numpy as np
+
+def optimize_tr_categorical(df_path):
+    df = pd.read_csv(df_path)
+    df['Bool Classes']=df['Binary Classes']!='Non-Flyer'
+    tr_list = []
+    acc_list =[]
+    for tr in range(200):
+        df['Opti Classes']=df['Non-Flyer']<=tr/200
+        conf_matrix = confusion_matrix(
+            df["Bool Classes"],
+            df["Opti Classes"]
+        )
+        tr_list.append(tr/200)
+        acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum())
+    plt.clf()
+    plt.plot(tr_list,acc_list)
+    plt.legend(['max acc : {}'.format(max(acc_list))])
+    plt.savefig('acc_opti_zeno.png')
+    plt.clf()
+    tr = tr_list[np.argmax(acc_list)]
+
+    df['Opti Classes']=df['Non-Flyer']<=tr
+    conf_matrix = confusion_matrix(
+        df["Bool Classes"],
+        df["Opti Classes"]
+    )
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer","Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    plt.savefig('best_conf_matrix_zeno_4.png', bbox_inches="tight", dpi=80)
+    plt.close()
+
+
+def optimize_tr_binary(df_path):
+    df = pd.read_csv(df_path)
+    df['Bool Classes']=df['Classes']!=0
+    tr_list = []
+    acc_list =[]
+    for tr in range(200):
+        df['Opti Classes']=df['Prob non flyer']<=tr/200
+        conf_matrix = confusion_matrix(
+            df["Bool Classes"],
+            df["Opti Classes"]
+        )
+        tr_list.append(tr/200)
+        acc_list.append((conf_matrix[0,0]+conf_matrix[1,1])/conf_matrix.sum())
+    plt.clf()
+    plt.plot(tr_list,acc_list)
+    plt.legend(['max acc : {}'.format(max(acc_list))])
+    plt.savefig('acc_opti_binary_astral.png')
+    plt.clf()
+    tr = tr_list[np.argmax(acc_list)]
+
+    df['Opti Classes']=df['Prob non flyer']<=tr
+    conf_matrix = confusion_matrix(
+        df["Bool Classes"],
+        df["Opti Classes"]
+    )
+
+    conf_matrix_disp = ConfusionMatrixDisplay(
+        confusion_matrix=conf_matrix, display_labels=["Non-Flyer","Flyer"]
+    )
+    fig, ax = plt.subplots()
+    conf_matrix_disp.plot(xticks_rotation=45, ax=ax)
+    plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11)
+    plt.savefig('best_conf_matrix_binary_astral.png', bbox_inches="tight", dpi=80)
+    plt.close()
+
+if __name__ == '__main__':
+    # optimize_tr_categorical(
+    #     'output/report_on_ISA (Fine tune model categorical train, binary val )/Dectetability_prediction_report.csv')
+    optimize_tr_binary('output/binary_astral/Dectetability_prediction_report.csv')
\ No newline at end of file