diff --git a/dataset_comparison.py b/dataset_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e6376399c99bb808eb698f284395b0fe2503254
--- /dev/null
+++ b/dataset_comparison.py
@@ -0,0 +1,64 @@
+import pandas as pd
+from datasets import load_dataset, DatasetDict
+
+
+df_list =["Wilhelmlab/detectability-proteometools", "Wilhelmlab/detectability-wang","Wilhelmlab/detectability-sinitcyn"]
+df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
+df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
+
+for label_type in ['Classes fragment','Classes precursor', 'Classes MaxLFQ'] :
+    df_full = pd.concat([df_flyer,df_no_flyer])
+    df_size = df_full.shape[0]
+    nb_no_flyer = df_full[df_full[label_type]==0].shape[0]
+    nb_weak_flyer = df_full[df_full[label_type] == 1].shape[0]
+    nb_intermediate_flyer = df_full[df_full[label_type] == 2].shape[0]
+    nb_strong_flyer = df_full[df_full[label_type] == 3].shape[0]
+    print('df ISA {} class repartition : No flyer {:.2f}%, Weak flyer {:.2f}%, Intermediate flyer {:.2f}%, Strong flyer {:.2f}%'.format(label_type,100*nb_no_flyer/df_size,100*nb_weak_flyer/df_size,100*nb_intermediate_flyer/df_size,100*nb_strong_flyer/df_size))
+
+l_inter_ISA=[]
+l_df_hg=[]
+for hf_data_name in df_list :
+
+    hf_dataset_split = load_dataset(hf_data_name)
+    l = [pd.DataFrame(hf_dataset_split[k]) for k in hf_dataset_split.keys()]
+    df_hg = pd.concat(l)
+
+    df_size = df_hg.shape[0]
+    nb_no_flyer = df_hg[df_hg['Classes']==0].shape[0]
+    nb_weak_flyer = df_hg[df_hg['Classes'] == 1].shape[0]
+    nb_intermediate_flyer = df_hg[df_hg['Classes'] == 2].shape[0]
+    nb_strong_flyer = df_hg[df_hg['Classes'] == 3].shape[0]
+    print('df {} class repartition : No flyer {:.2f}%, Weak flyer {:.2f}%, Intermediate flyer {:.2f}%, Strong flyer {:.2f}%'.format(hf_data_name,100*nb_no_flyer/df_size,100*nb_weak_flyer/df_size,100*nb_intermediate_flyer/df_size,100*nb_strong_flyer/df_size))
+
+    df_common = df_hg.join(df_full.set_index('Sequences'),on='Sequences',how='inner',lsuffix='_hg',rsuffix='_ISA')
+    size_inter = df_common.shape[0]
+    same_label = df_common[df_common['Classes']==df_common['Classes MaxLFQ']].shape[0]
+    l_inter_ISA.append(df_common)
+    print('Inter with ISA df size : {}, similar label : {:.2f}%'.format(size_inter,100*same_label/size_inter))
+
+    for df_hg_bis in l_df_hg :
+        df_common = df_hg.join(df_hg_bis.set_index('Sequences'), on='Sequences', how='inner', lsuffix='_hg',
+                               rsuffix='_hg_bis')
+        size_inter = df_common.shape[0]
+        same_label = df_common[df_common['Classes_hg'] == df_common['Classes_hg_bis']]
+        same_label_size = same_label.shape[0]
+        cf_matrix = pd.crosstab(df_common['Classes_hg'], df_common['Classes_hg_bis'])
+        print('Inter with df hg bis df size : {}, similar label : {:.2f}%'.format(size_inter, 100 * same_label_size / size_inter))
+        print(cf_matrix)
+    l_df_hg.append(df_hg)
+
+
+
+
+
+
+
+#
+# target_labels = {
+#     0: "Non-Flyer",
+#     1: "Weak Flyer",
+#     2: "Intermediate Flyer",
+#     3: "Strong Flyer",
+# }
+# binary_labels = {0: "Non-Flyer", 1: "Flyer"}
+
diff --git a/dataset_extraction.py b/dataset_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7b8e1802cebb34fb971800228e5b4076f4535bb
--- /dev/null
+++ b/dataset_extraction.py
@@ -0,0 +1,86 @@
+import pandas as pd
+
+"""
+target_labels = {
+    0: "Non-Flyer",
+    1: "Weak Flyer",
+    2: "Intermediate Flyer",
+    3: "Strong Flyer",
+}
+binary_labels = {0: "Non-Flyer", 1: "Flyer"}
+"""
+
+
+def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, min_peptide = 4, f_name='out_df.csv'):
+    df = pd.read_excel('ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx')
+    df_non_flyer = pd.read_csv('ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv')
+    #No flyer
+    df_non_flyer = df_non_flyer[df_non_flyer['Cystein ? ']=='Any']
+    df_non_flyer = df_non_flyer[df_non_flyer['Miscleavage ?'] == 'Any']
+    df_non_flyer = df_non_flyer[df_non_flyer['MaxLFQ'] == 0.0]
+    df_non_flyer['Sequences'] = df_non_flyer['Peptide']
+    df_non_flyer['Proteins'] = df_non_flyer['ProteinID']
+    df_non_flyer=df_non_flyer[['Sequences','Proteins']].drop_duplicates()
+    df_non_flyer['Classes fragment']=0
+    df_non_flyer['Classes precursor'] =0
+    df_non_flyer['Classes MaxLFQ'] =0
+
+
+    #Flyer
+    df_filtered = df[df['Proteotypic ?']=='Proteotypic']
+    df_filtered = df_filtered[df_filtered['Coverage ']>=coverage_treshold]
+    df_filtered = df_filtered[df_filtered['Miscleavage ?']=='Any']
+    peptide_count=df_filtered.groupby(["Protein.Names"]).size().reset_index(name='counts')
+    filtered_sequence = peptide_count[peptide_count['counts']>=min_peptide]["Protein.Names"]
+    df_filtered = df_filtered[df_filtered["Protein.Names"].isin(filtered_sequence.to_list())]
+
+    df1_grouped = df_filtered.groupby("Protein.Names")
+    dico_final={}
+    # iterate over each group
+    for group_name, df_group in df1_grouped:
+        seq = df_group.sort_values(by=[intensity_col])['Stripped.Sequence'].to_list()
+        value_frag = df_group.sort_values(by=[intensity_col])[intensity_col].to_list()
+        value_prec = df_group.sort_values(by=['Precursor.Quantity'])['Precursor.Quantity'].to_list()
+        value_prec_frag = df_group.sort_values(by=[intensity_col])['Precursor.Quantity'].to_list()
+        value_maxlfq = df_group.sort_values(by=['MaxLFQ'])['MaxLFQ'].to_list()
+        value_maxlfq_frag = df_group.sort_values(by=[intensity_col])['MaxLFQ'].to_list()
+        threshold_weak_flyer_frag = value_frag[int(len(seq) / 3)]
+        threshold_medium_flyer_frag = value_frag[int(2*len(seq) / 3)]
+        threshold_weak_flyer_prec = value_prec[int(len(seq) / 3)]
+        threshold_medium_flyer_prec = value_prec[int(2 * len(seq) / 3)]
+        threshold_weak_flyer_maxflq = value_maxlfq[int(len(seq) / 3)]
+        threshold_medium_flyer_maxlfq = value_maxlfq[int(2 * len(seq) / 3)]
+        prot = df_group['Protein.Group'].to_list()[0]
+
+        for i in range(len(seq)):
+            if value_frag[i] < threshold_weak_flyer_frag :
+                label_frag = 1
+            elif value_frag[i] < threshold_medium_flyer_frag :
+                label_frag = 2
+            else :
+                label_frag = 3
+
+            if value_prec_frag[i] < threshold_weak_flyer_prec :
+                label_prec = 1
+            elif value_prec_frag[i] < threshold_medium_flyer_prec :
+                label_prec = 2
+            else :
+                label_prec = 3
+
+            if value_maxlfq_frag[i] < threshold_weak_flyer_maxflq :
+                label_maxlfq = 1
+            elif value_maxlfq_frag[i] < threshold_medium_flyer_maxlfq :
+                label_maxlfq = 2
+            else :
+                label_maxlfq = 3
+
+            dico_final[seq[i]] = (prot,label_frag,label_prec,label_maxlfq)
+
+    df_final = pd.DataFrame.from_dict(dico_final, orient='index',columns=['Proteins', 'Classes fragment','Classes precursor', 'Classes MaxLFQ'])
+    df_final['Sequences']=df_final.index
+    df_final = df_final.reset_index()
+    df_final=df_final[['Sequences','Proteins','Classes fragment','Classes precursor', 'Classes MaxLFQ']]
+    df_final.to_csv(f_name, index=False)
+    df_non_flyer.to_csv('ISA_data/df_non_flyer_no_miscleavage.csv', index=False)
+if __name__ == '__main__':
+    build_dataset( coverage_treshold=20, min_peptide=4, f_name='ISA_data/df_finetune_no_miscleavage.csv')
diff --git a/main.py b/main.py
index 76475bd01ab5d0f7292deb7b8e803eb9bf6a86f5..08fe03cd7cd0279673f6ef56a72b679c5db865b9 100644
--- a/main.py
+++ b/main.py
@@ -165,4 +165,4 @@ def main(input_data_path):
 
 if __name__ == '__main__':
     # main()
-    main('241205_list_test_peptide_detectability.txt')
+    main('output/241205_list_test_peptide_detectability.txt')
diff --git a/main_fine_tune.py b/main_fine_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..e295ebe7693d2ec667a3860d5af2418c99601833
--- /dev/null
+++ b/main_fine_tune.py
@@ -0,0 +1,185 @@
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from dlomix.models import DetectabilityModel
+from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict
+from dlomix.data import DetectabilityDataset
+from datasets import load_dataset, DatasetDict
+from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report
+
+
+def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0.8,0.2),frac_no_fly_train=1,frac_no_fly_val=2):
+    df_flyer = pd.read_csv('ISA_data/df_finetune_no_miscleavage.csv')
+    df_no_flyer = pd.read_csv('ISA_data/df_non_flyer_no_miscleavage.csv')
+    df_no_flyer['Classes'] = df_no_flyer[classe_type]
+    df_no_flyer = df_no_flyer[['Sequences', 'Classes']]
+    df_flyer['Classes']=df_flyer[classe_type]
+    df_flyer = df_flyer[['Sequences','Classes']]
+    #stratified split
+    list_train_split=[]
+    list_val_split =[]
+    for cl in [1,2,3]:
+        df_class = df_flyer[df_flyer['Classes']==cl]
+        class_count = df_class.shape[0]
+        list_train_split.append(df_class.iloc[:int(class_count*split[0]),:])
+        list_val_split.append(df_class.iloc[int(class_count * split[0]):, :])
+
+    list_train_split.append(df_no_flyer.iloc[:int(class_count * split[0] * frac_no_fly_train), :])
+    list_val_split.append(df_no_flyer.iloc[df_no_flyer.shape[0]-int(class_count * split[1] * frac_no_fly_val):, :])
+
+    df_train = pd.concat(list_train_split).sample(frac=1, random_state=manual_seed)
+    df_val = pd.concat(list_val_split).sample(frac=1, random_state=manual_seed)
+
+    df_train['Proteins']=0
+    df_val['Proteins'] = 0
+    df_train.to_csv('temp_fine_tune_df_train.csv', index=False)
+    df_val.to_csv('temp_fine_tune_df_val.csv',index=False)
+
+def density_plot(prediction_path,prediction_path_2,criteria='base'):
+    df = pd.read_csv(prediction_path)
+    df['actual_metrics']=-df['Non-Flyer']+df[["Weak Flyer", "Strong Flyer",'Intermediate Flyer']].max(axis=1)
+    flyer = df[df['Binary Classes']=='Flyer']
+    non_flyer = df[df['Binary Classes'] == 'Non-Flyer']
+
+    df2 = pd.read_csv(prediction_path_2)
+    df2['actual_metrics'] = -df2['Non-Flyer'] + df2[["Weak Flyer", "Strong Flyer", 'Intermediate Flyer']].max(axis=1)
+    flyer2 = df2[df2['Binary Classes'] == 'Flyer']
+    non_flyer2 = df2[df2['Binary Classes'] == 'Non-Flyer']
+
+
+    non_flyer['Flyer'].plot.density(bw_method=0.2, color='blue', linestyle='-', linewidth=2)
+    flyer['Flyer'].plot.density(bw_method=0.2, color='red', linestyle='-', linewidth=2)
+    non_flyer2['Flyer'].plot.density(bw_method=0.2, color='yellow', linestyle='-', linewidth=2)
+    flyer2['Flyer'].plot.density(bw_method=0.2, color='green', linestyle='-', linewidth=2)
+    plt.legend(['non-flyer base','flyer base','non-flyer fine tuned','flyer fine tuned'])
+    plt.xlabel('Flyer index')
+    plt.xlim(0,1)
+    plt.show()
+    plt.savefig('density_sum.png')
+    plt.clf()
+    #Weak Flyer,Intermediate Flyer,Strong Flyer
+
+    non_flyer['actual_metrics'].plot.density(bw_method=0.2, color='blue', linestyle='-', linewidth=2)
+    flyer['actual_metrics'].plot.density(bw_method=0.2, color='red', linestyle='-', linewidth=2)
+    non_flyer2['actual_metrics'].plot.density(bw_method=0.2, color='yellow', linestyle='-', linewidth=2)
+    flyer2['actual_metrics'].plot.density(bw_method=0.2, color='green', linestyle='-', linewidth=2)
+    plt.legend(['non-flyer base','flyer base','non-flyer fine tuned','flyer fine tuned'])
+    plt.xlabel('Flyer index')
+    plt.xlim(-1,1)
+    plt.show()
+    plt.savefig('density_base.png')
+
+
+
+
+
+def main():
+    total_num_classes = len(CLASSES_LABELS)
+    input_dimension = len(alphabet)
+    num_cells = 64
+
+    fine_tuned_model = DetectabilityModel(num_units=num_cells, num_clases=total_num_classes)
+    load_model_path = 'pretrained_model/original_detectability_fine_tuned_model_FINAL'
+    fine_tuned_model = DetectabilityModel(num_units=num_cells,
+                                          num_clases=total_num_classes)
+
+    fine_tuned_model.load_weights(load_model_path)
+
+
+
+    max_pep_length = 40
+    ## Has no impact for prediction
+    batch_size = 16
+
+    print('Initialising dataset')
+    ## Data init
+    fine_tune_data = DetectabilityDataset(data_source='temp_fine_tune_df_train.csv',
+                                              val_data_source='temp_fine_tune_df_val.csv',
+                                              data_format='csv',
+                                              max_seq_len=max_pep_length,
+                                              label_column="Classes",
+                                              sequence_column="Sequences",
+                                              dataset_columns_to_keep=["Proteins"],
+                                              batch_size=batch_size,
+                                              with_termini=False,
+                                              alphabet=aa_to_int_dict)
+
+
+
+
+    # compile the model  with the optimizer and the metrics we want to use.
+    callback_FT = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
+                                                   mode='min',
+                                                   verbose=1,
+                                                   patience=5)
+
+    model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
+
+    model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,
+                                                             monitor='val_loss',
+                                                             mode='min',
+                                                             verbose=1,
+                                                             save_best_only=True,
+                                                             save_weights_only=True)
+    opti = tf.keras.optimizers.legacy.Adagrad()
+    fine_tuned_model.compile(optimizer=opti,
+                             loss='SparseCategoricalCrossentropy',
+                             metrics='sparse_categorical_accuracy')
+
+
+
+    history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,
+                                              validation_data=fine_tune_data.tensor_val_data,
+                                              epochs=1,
+                                              callbacks=[callback_FT, model_checkpoint_FT])
+
+    ## Loading best model weights
+
+    # model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'
+    model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
+
+    fine_tuned_model.load_weights(model_save_path_FT)
+
+    predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_val_data)
+
+    # access val dataset and get the Classes column
+    test_targets_FT = fine_tune_data["val"]["Classes"]
+
+    # if needed, the decoded version of the classes can be retrieved by looking up the class names
+    test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]
+
+    # The dataframe needed for the report
+
+    test_data_df_FT = pd.DataFrame(
+        {
+            "Sequences": fine_tune_data["val"]["_parsed_sequence"],  # get the raw parsed sequences
+            "Classes": test_targets_FT,  # get the test targets from above
+            "Proteins": fine_tune_data["val"]["Proteins"]  # get the Proteins column from the dataset object
+        }
+    )
+
+    test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: "".join(x))
+
+    # Since the detectabiliy report expects the true labels in one-hot encoded format, we expand them here.
+
+    num_classes = np.max(test_targets_FT) + 1
+    test_targets_FT_one_hot = np.eye(num_classes)[test_targets_FT]
+    test_targets_FT_one_hot.shape, len(test_targets_FT)
+
+    report_FT = DetectabilityReport(test_targets_FT_one_hot,
+                                    predictions_FT,
+                                    test_data_df_FT,
+                                    output_path='./output/report_on_ISA (Base model categorical train, binary val )',
+                                    history=history_fine_tuned,
+                                    rank_by_prot=True,
+                                    threshold=None,
+                                    name_of_dataset='ISA val dataset (binary balanced)',
+                                    name_of_model='Base model (ISA)')
+
+    report_FT.generate_report()
+
+if __name__ == '__main__':
+    create_ISA_dataset()
+    main()
+    # density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')