import os

import fastapy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib_venn import venn2
from analyse_diann_digestion import load_lib
import matplotlib.image as mpimg
import re

# SPECIES = [f.name for f in os.scandir('../data/processed_data_wiff/npy_image/all data') if f.is_dir()]
# Fixed list of target Enterobacterales species used throughout this script.
# The commented scandir line above derives the same list from the processed-data
# folder names; the literal is kept so the script runs without that directory.
SPECIES = ['Enterobacter bugandensis', 'Enterobacter hormaechei', 'Enterobacter cloacae',
           'Serratia marcescens', 'Morganella morganii', 'Citrobacter freundii', 'Hafnia alvei',
           'Citrobacter amalonaticus', 'Salmonella enterica', 'Klebsiella quasipneumoniae',
           'Raoultella ornithinolytica', 'Escherichia coli', 'Hafnia paralvei', 'Citrobacter portucalensis',
           'Proteus penneri', 'Providencia stuartii',
           'Klebsiella michiganensis', 'Klebsiella variicola', 'Proteus terrae', 'Pantoea septica',
           'Proteus vulgaris', 'Proteus mirabilis', 'Klebsiella oxytoca',
           'Klebsiella aerogenes', 'Proteus columbae',
           'Citrobacter koseri', 'Enterobacter kobei',
           'Klebsiella pneumoniae', 'Pantoea agglomerans', 'Klebsiella grimontii',
           'Providencia rettgeri',
           'Citrobacter farmeri', 'Enterobacter chengduensis', 'Enterobacter ludwigii',
           'Enterobacter asburiae', 'Citrobacter braakii', 'Enterobacter roggenkampii', 'Citrobacter cronae']


# Trypsin cleaves after K or R, unless the residue is followed by P.

def cut(seq):
    """Return the tryptic cleavage positions of *seq*.

    A cut occurs just after each K or R that is not followed by P.
    Positions are indices into *seq* where the split happens (i.e. the
    start of the next fragment).
    """
    return [pos + 1
            for pos in range(len(seq) - 1)
            if seq[pos] in ('K', 'R') and seq[pos + 1] != 'P']


def cut_with_ind(seq, ind_list):
    """Split *seq* at each position in *ind_list* and return the fragments.

    Args:
        seq: The sequence to split.
        ind_list: Ascending split positions (as produced by ``cut``).

    Returns:
        List of substrings covering *seq* exactly; ``[seq]`` when
        *ind_list* is empty.

    Fixes over the previous version: the caller's *ind_list* is no longer
    mutated (it used to get ``len(seq)`` appended), and a sequence with no
    cut sites now yields itself instead of an empty list.
    """
    bounds = [0] + list(ind_list) + [len(seq)]
    return [seq[start:end] for start, end in zip(bounds, bounds[1:])]


def digest(seq):
    """Simulate a tryptic digest: split *seq* at every trypsin cleavage site."""
    return cut_with_ind(seq, cut(seq))


def fasta_similarity(path_fasta_1, path_fasta_2):
    """Plot and save a Venn diagram of the tryptic peptides of two FASTA files.

    Each FASTA file is digested in silico (see ``digest``) and the two
    resulting peptide sets are compared with a 2-set Venn diagram, shown
    on screen and written to 'fasta_similarity.png'.
    """

    def _peptide_set(path):
        # All unique tryptic peptides across every record of one FASTA file.
        peptides = []
        for record in fastapy.parse(path):
            peptides.extend(digest(record.seq))
        return set(peptides)

    set1 = _peptide_set(path_fasta_1)
    set2 = _peptide_set(path_fasta_2)

    venn2((set1, set2), ('Group1', 'Group2'))
    # BUG FIX: save before show() — plt.show() closes the figure, so calling
    # savefig afterwards wrote a blank image.
    plt.savefig('fasta_similarity.png')
    plt.show()


def split_string(input_string):
    """Split *input_string* before each uppercase letter preceded by an
    underscore or a lowercase 'c'; the separator character is dropped.
    """
    # '[_c](?=[A-Z])' matches a single '_' or 'c' only when the next
    # character is uppercase, so the split consumes just that separator.
    return re.split(r'[_c](?=[A-Z])', input_string)


def build_database_ref_peptide():
    """Parse the reference-peptide text file into a deduplicated DataFrame.

    The file alternates header lines (containing '>') that carry protein and
    taxonomy metadata, and peptide lines whose second space-separated field is
    the sequence. Each peptide row inherits the metadata of the most recent
    header line.

    Returns:
        pd.DataFrame with columns 'Sequence', 'Protein code', 'Error treshold',
        'Prevalance', 'Specie', 'Genus', 'Family', 'Order' (duplicates removed).

    NOTE(review): assumes the first non-blank line is a header — a leading
    peptide line would hit unbound locals (prot, err, ...). Confirm the file
    always starts with a '>' entry.
    """
    l = []
    with open('../data/label_raw/250107_FASTA_RP_GroEL_GroES_Tuf_5pct_assemble_peptides_list.txt', 'r') as f:
        for line in f:
            if line != '\n':
                if '>' in line:
                    # remove new line
                    line = line.replace('\n', '')

                    # Normalise inconsistent capitalisation in the source file
                    # (looks like an upstream typo).
                    line = line.replace('no_family', 'No_family')
                    line = line.replace('no_order', 'No_order')

                    # Header layout: '>PROT_ERR_PREV ...' — first three
                    # underscore-separated tokens, '>' stripped from the first.
                    split_line = line.split('_')
                    prot = split_line[0][1:]
                    err = split_line[1]
                    prev = split_line[2]
                    # Taxonomy is the second space-separated token, further
                    # split before uppercase letters (see split_string).
                    split_line = split_string(line.split(' ')[1])
                    spe = split_line[0].replace('_', ' ')
                    gen = split_line[1].replace('_', ' ')
                    fam = split_line[2].replace('_', ' ')
                    o = split_line[3].replace('_', ' ')
                else:
                    # Peptide line: sequence is the second space-separated field.
                    seq = line.split(' ')[1]
                    # remove new lines
                    seq = seq.replace('\n', '')
                    # NOTE(review): 'Error treshold'/'Prevalance' spellings are
                    # intentional here — they are the column names of the CSV
                    # consumed elsewhere; do not "fix" without migrating it.
                    l.append({'Sequence': seq, 'Protein code': prot, 'Error treshold': err, 'Prevalance': prev,
                              'Specie': spe, 'Genus': gen, 'Family': fam, 'Order': o})
    df = pd.DataFrame(l)
    return df.drop_duplicates()


def compute_common_peptide(df_path, species):
    """Count, for each peptide, how many of *species* it occurs in.

    Args:
        df_path: CSV with at least 'Sequence' and 'Specie' columns.
        species: Iterable of species names to restrict the count to.

    Returns:
        pd.DataFrame with columns 'Sequence' and 'Count' (one row per
        distinct peptide).
    """
    peptides = pd.read_csv(df_path)
    # One row per (peptide, species) pair within the requested species.
    subset = peptides.loc[peptides['Specie'].isin(species), ['Sequence', 'Specie']].drop_duplicates()
    # Number of distinct species each peptide appears in.
    subset['Count'] = subset.groupby('Sequence')['Sequence'].transform('count')
    return subset[['Sequence', 'Count']].drop_duplicates()


def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, min_rt=None, max_rt=None):
    """Rasterise a DIA-NN spectral library into a binary (RT x m/z) image.

    Args:
        path_parqet: Path to the DIA-NN .parquet library (read via load_lib).
        ms1_end_mz / ms1_start_mz: m/z window covered by the image columns.
        bin_mz: Width of one m/z bin.
        max_cycle: Number of image rows (retention-time bins).
        min_rt / max_rt: RT window; default to the library's own min/max.

    Returns:
        np.ndarray of shape (max_cycle, n_bins): 1.0 where at least one
        precursor falls in the (RT, m/z) cell, 0.0 elsewhere.
    """
    df = load_lib(path_parqet)
    df = df[['Stripped.Sequence', 'Precursor.Charge', 'RT', 'Precursor.Mz']]
    df_unique = df.drop_duplicates()

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = int(total_ms1_mz // bin_mz)
    im = np.zeros([max_cycle, n_bin_ms1])
    if max_rt is None:
        max_rt = np.max(df_unique['RT'])
    if min_rt is None:
        min_rt = np.min(df_unique['RT'])
    # Epsilon keeps rt == max_rt from mapping to row index max_cycle.
    total_rt = max_rt - min_rt + 1e-3
    for _, row in df_unique.iterrows():
        mz_idx = int((row['Precursor.Mz'] - ms1_start_mz) / total_ms1_mz * n_bin_ms1)
        rt_idx = int((row['RT'] - min_rt) / total_rt * max_cycle)
        # BUG FIX: the upper m/z bound was hard-coded to 900 (only valid for a
        # 350-1250 window at bin_mz=1); use n_bin_ms1. The RT index is guarded
        # too, since an externally supplied RT window can exclude rows that
        # would otherwise index out of bounds or wrap via negative indices.
        if 0 <= mz_idx < n_bin_ms1 and 0 <= rt_idx < max_cycle:
            im[rt_idx, mz_idx] = 1

    return im


def build_ref_image_from_diann_global(path_parqet, target_seq, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, min_rt=None,
                                      max_rt=None):
    """Rasterise the subset of a DIA-NN library matching *target_seq*.

    Same output as build_ref_image_from_diann, but only precursors whose
    'Stripped.Sequence' is in *target_seq* contribute to the image.

    Args:
        path_parqet: Path to the DIA-NN .parquet library (read via load_lib).
        target_seq: Collection of stripped peptide sequences to keep.
        ms1_end_mz / ms1_start_mz: m/z window covered by the image columns.
        bin_mz: Width of one m/z bin.
        max_cycle: Number of image rows (retention-time bins).
        min_rt / max_rt: RT window; default to the filtered data's min/max.

    Returns:
        np.ndarray of shape (max_cycle, n_bins): 1.0 where at least one
        selected precursor falls in the (RT, m/z) cell, 0.0 elsewhere.
    """
    df = load_lib(path_parqet)
    df = df[['Stripped.Sequence', 'Precursor.Charge', 'RT', 'Precursor.Mz']]
    df = df[df['Stripped.Sequence'].isin(target_seq)]
    df_unique = df.drop_duplicates()

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = int(total_ms1_mz // bin_mz)
    im = np.zeros([max_cycle, n_bin_ms1])
    if max_rt is None:
        max_rt = np.max(df_unique['RT'])
    if min_rt is None:
        min_rt = np.min(df_unique['RT'])
    # Epsilon keeps rt == max_rt from mapping to row index max_cycle.
    total_rt = max_rt - min_rt + 1e-3
    for _, row in df_unique.iterrows():
        mz_idx = int((row['Precursor.Mz'] - ms1_start_mz) / total_ms1_mz * n_bin_ms1)
        rt_idx = int((row['RT'] - min_rt) / total_rt * max_cycle)
        # BUG FIX: the upper m/z bound was hard-coded to 900 (only valid for a
        # 350-1250 window at bin_mz=1); use n_bin_ms1. The RT index is guarded
        # too, since an externally supplied RT window can exclude rows that
        # would otherwise index out of bounds or wrap via negative indices.
        if 0 <= mz_idx < n_bin_ms1 and 0 <= rt_idx < max_cycle:
            im[rt_idx, mz_idx] = 1

    return im


if __name__ == '__main__':
    # Pipeline — earlier stages are kept commented out for reference:
    #   1. Build the reference-peptide table and save it as CSV.
    # df = build_database_ref_peptide()
    # df.to_csv("dataset_species_ref_peptides.csv", index=False)
    # Write fasta file
    # with open('fasta/optimal peptide set/gobal_peptide_set.fasta', "w") as f:
    #     df_spe = df[df['Specie'].isin(SPECIES)]
    #     spe_list = df_spe['Sequence'].drop_duplicates().to_list()
    #
    #     for pep in spe_list :
    #         print(pep)
    #         f.write(pep+'\n')
    #
    # Per-peptide count of how many target species share the sequence.
    df_count = compute_common_peptide("dataset_species_ref_peptides.csv", SPECIES)

    #
    # Create ref img
    # A full-proteome library fixes one common RT window for every image so
    # the per-species images are comparable row-for-row.
    df_full = load_lib(
        'fasta/full proteom/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
    min_rt = df_full['RT'].min()
    max_rt = df_full['RT'].max()
    #
    df = pd.read_csv("dataset_species_ref_peptides.csv")
    #
    for spe in SPECIES:
        print(spe)
        df_spe = df[df['Specie'] == spe]
        # Keep only peptides shared by fewer than 5 of the target species,
        # i.e. reasonably specific to this one.
        df_spec_no_common = df_spe[df_spe['Sequence'].isin(df_count[df_count['Count']<5]['Sequence'])]
        im = build_ref_image_from_diann_global(
            'fasta/global_peptide_list.parquet', target_seq=df_spec_no_common['Sequence'].to_list(), ms1_end_mz=1250,
            ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
        plt.clf()
        # Save both a viewable PNG and the raw array for downstream loading.
        # NOTE(review): assumes the 'img_ref_common_th_5/' directory exists.
        mpimg.imsave('img_ref_common_th_5/' + spe + '.png', im)
        np.save('img_ref_common_th_5/' + spe + '.npy', im)