import fastapy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib_venn import venn2
from analyse_diann_digestion import load_lib
import matplotlib.image as mpimg
import re

# Integer code for each of the 20 standard (unmodified) amino acids: the
# one-letter codes in alphabetical order, numbered consecutively from 1.
ALPHABET_UNMOD = {
    residue: code for code, residue in enumerate("ACDEFGHIKLMNPQRSTVWY", start=1)
}


# Monoisotopic residue masses (Da) of the 20 standard amino acids, i.e. the
# mass each residue contributes inside a peptide chain (free amino acid minus
# one H2O). A full peptide's mass additionally needs one water (18.010565 Da).
MASSES_MONO = {
    "A": 71.03711,
    "C": 103.00919,
    "D": 115.02694,
    "E": 129.04259,
    "F": 147.06841,
    "G": 57.02146,
    "H": 137.05891,
    "I": 113.08406,
    "K": 128.09496,
    "L": 113.08406,
    "M": 131.04049,
    "N": 114.04293,
    "P": 97.05276,
    "Q": 128.05858,
    "R": 156.10111,  # monoisotopic; 156.1875 is arginine's *average* residue mass
    "S": 87.03203,
    "T": 101.04768,
    "V": 99.06841,
    "W": 186.07931,
    "Y": 163.06333,
}


# Average residue masses (Da) of the 20 standard amino acids, keyed by
# one-letter code in the same order as MASSES_MONO.
MASSES_AVG = dict(
    A=71.0788,
    C=103.1388,
    D=115.0886,
    E=129.1155,
    F=147.1766,
    G=57.0519,
    H=137.1411,
    I=113.1594,
    K=128.1741,
    L=113.1594,
    M=131.1926,
    N=114.1038,
    P=97.1167,
    Q=128.1307,
    R=156.1875,
    S=87.0782,
    T=101.1051,
    V=99.1326,
    W=186.2132,
    Y=163.1760,
)

# trypsin cut after K or R (if not followed by P)

def cut(seq):
    """Return the trypsin cleavage positions of *seq*, in ascending order.

    A returned position ``p`` means the chain is split between ``seq[p - 1]``
    and ``seq[p]``. A cut is made after every K or R whose following residue
    is not P; a K or R in the final position yields no cut, since nothing
    follows it.
    """
    return [
        pos + 1
        for pos, (current, following) in enumerate(zip(seq, seq[1:]))
        if current in ('R', 'K') and following != 'P'
    ]


def cut_with_ind(seq, ind_list):
    """Split *seq* at the given cleavage positions.

    Args:
        seq: sequence to split (anything supporting ``len`` and slicing).
        ind_list: split positions; position ``p`` separates ``seq[p - 1]``
            from ``seq[p]``. Assumed ascending, as produced by ``cut``.
            The list is not modified.

    Returns:
        The consecutive pieces of *seq* between the positions. With no
        positions the whole sequence is returned as a single piece; an empty
        *seq* yields an empty list.
    """
    if not seq:
        return []
    # Build a fresh boundary list instead of appending to the caller's list.
    bounds = [0, *ind_list, len(seq)]
    return [seq[start:end] for start, end in zip(bounds, bounds[1:])]

def digest(seq):
    """Return the tryptic peptides of *seq*.

    Cleavage sites come from ``cut`` and the sequence is split at them by
    ``cut_with_ind``.
    """
    return cut_with_ind(seq, cut(seq))

def fasta_similarity(path_fasta_1, path_fasta_2, set_labels=('Group1', 'Group2'),
                     out_path='fasta_similarity.png'):
    """Draw a Venn diagram of the tryptic peptides of two FASTA files.

    Every record of each file is digested with ``digest`` and the two sets of
    distinct peptides are compared with ``venn2``. The figure is written to
    *out_path* and then shown.

    Args:
        path_fasta_1, path_fasta_2: FASTA files to compare.
        set_labels: labels of the two circles.
        out_path: file the figure is saved to.
    """
    def peptide_set(path):
        # Distinct tryptic peptides over all records of one FASTA file.
        return {pep for record in fastapy.parse(path) for pep in digest(record.seq)}

    set1 = peptide_set(path_fasta_1)
    set2 = peptide_set(path_fasta_2)

    venn2((set1, set2), set_labels)
    # Save before show(): a blocking show() closes the figure when its window
    # is closed, so a savefig() afterwards would write an empty image.
    plt.savefig(out_path)
    plt.show()

# Split point: an underscore immediately followed by an uppercase letter or a
# lowercase 'c' (the lookahead keeps that character in the next piece).
_SPLIT_BEFORE_FIELD = re.compile(r'_(?=[A-Zc])')


def split_string(input_string):
    """Split *input_string* at every ``'_'`` that precedes an uppercase letter or ``'c'``.

    Underscores followed by any other character stay inside their piece, so
    ``"Proteus_mirabilis_Proteus"`` becomes ``["Proteus_mirabilis", "Proteus"]``.
    """
    return _SPLIT_BEFORE_FIELD.split(input_string)

def build_database_ref_peptide(
        path='../data/label_raw/250107_FASTA_RP_GroEL_GroES_Tuf_5pct_assemble_peptides_list.txt'):
    """Parse the annotated reference-peptide list into a DataFrame.

    The file is read line by line. Lines containing ``'>'`` are headers that
    set the current protein/taxonomy annotation. Every other non-blank line is
    a peptide line and is attached to the most recent header.

    Header layout, as implied by the parsing below (NOTE(review): confirm
    against the file)::

        >PROT_ERR_PREV... Species_name_Genus_Family_Order

    The first three ``'_'``-separated fields of the header are the protein
    code, error threshold and prevalence. The text after the first space is
    the taxonomy, split with ``split_string``. Peptide lines hold the sequence
    as their second space-separated field.

    Args:
        path: peptide-list file to read; defaults to the original hard-coded
            location (relative to the working directory).

    Returns:
        DataFrame with one row per peptide line and the columns 'Sequence',
        'Protein code', 'Error treshold', 'Prevalance', 'Specie', 'Genus',
        'Family', 'Order'. The names are kept verbatim, spelling included,
        because callers may index by them.

    Raises:
        ValueError: if a peptide line appears before any header line.
    """
    rows = []
    annotation = None
    with open(path, 'r') as f:
        for raw_line in f:
            # Drop the line terminator so it cannot leak into the last header
            # field ('Order') or into the peptide sequence.
            line = raw_line.rstrip()
            if not line:
                continue
            if '>' in line:
                # Capitalise these placeholders so split_string (which splits at
                # '_' followed by an uppercase letter or 'c') treats them as
                # separate taxonomy fields.
                line = line.replace('no_family', 'No_family')
                line = line.replace('no_order', 'No_order')

                code_fields = line.split('_')
                taxonomy = split_string(line.split(' ')[1])
                annotation = {
                    'Protein code': code_fields[0][1:],  # drop the leading '>'
                    'Error treshold': code_fields[1],
                    'Prevalance': code_fields[2],
                    'Specie': taxonomy[0].replace('_', ' '),
                    'Genus': taxonomy[1].replace('_', ' '),
                    'Family': taxonomy[2].replace('_', ' '),
                    'Order': taxonomy[3].replace('_', ' '),
                }
            else:
                if annotation is None:
                    raise ValueError(
                        f"peptide line before any header in {path!r}: {line!r}")
                rows.append({'Sequence': line.split(' ')[1], **annotation})
    return pd.DataFrame(rows)

def compute_mass(seq, isotop):
    """Return the summed residue mass of *seq*.

    Args:
        seq: peptide sequence in one-letter amino-acid codes.
        isotop: ``'mono'`` to use MASSES_MONO or ``'avg'`` to use MASSES_AVG.

    Returns:
        The sum of the residue masses of *seq*. No terminal water or charge
        carriers are added, and characters that are not keys of the chosen
        table contribute nothing.

    Raises:
        ValueError: if *isotop* is neither ``'mono'`` nor ``'avg'``. Previously
            such a value silently produced a mass of 0.
    """
    tables = {'mono': MASSES_MONO, 'avg': MASSES_AVG}
    try:
        masses = tables[isotop]
    except KeyError:
        raise ValueError(f"isotop must be 'mono' or 'avg', got {isotop!r}") from None
    # Accumulate in the table's key order (same float summation order as before).
    return sum(mass * seq.count(residue) for residue, mass in masses.items())

def build_ref_image(path_fasta, possible_charge, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, rt_pred):
    """Build a binary (RT cycle x m/z bin) reference image from an in-silico tryptic digest.

    NOTE(review): work in progress. Precursor m/z values are computed for every
    peptide, but no retention time is associated with them yet (``data`` below
    is never filled). The returned image is therefore currently all zeros.

    Args:
        path_fasta: FASTA file whose records are digested with ``digest``.
        possible_charge: iterable of precursor charge states to consider.
        ms1_end_mz, ms1_start_mz: MS1 m/z window; only precursors strictly
            inside it are kept.
        bin_mz: width of one m/z bin.
        max_cycle: number of image rows (RT cycles).
        rt_pred: predicted retention times. Only their min and max are used,
            to scale the RT axis.

    Returns:
        ndarray of shape ``(max_cycle, int((ms1_end_mz - ms1_start_mz) // bin_mz))``.
    """
    water_avg = 18.01528  # average mass of H2O (Da): one per peptide
    proton = 1.007276     # proton mass (Da): one per charge

    # Distinct tryptic peptides over every record of the FASTA file.
    peptides = set()
    for record in fastapy.parse(path_fasta):
        peptides.update(digest(record.seq))

    # m/z of each [M + zH]z+ precursor strictly inside the MS1 window, keyed by
    # peptide sequence (not yet used to draw the image, see TODO below).
    mz_ratio = {}
    n_in_window = 0
    for seq in peptides:
        neutral_mass = compute_mass(seq, 'avg') + water_avg
        in_window = []
        for charge in possible_charge:
            mz = (neutral_mass + charge * proton) / charge
            if ms1_start_mz < mz < ms1_end_mz:
                in_window.append(mz)
        mz_ratio[seq] = in_window
        n_in_window += len(in_window)
    print(n_in_window)

    # TODO: associate each peptide m/z with a predicted RT and fill `data` with
    # (rt, mz) pairs; until then nothing is drawn.
    data = []

    # Optional: predict detectability.

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = int(total_ms1_mz // bin_mz)
    im = np.zeros([max_cycle, n_bin_ms1])
    max_rt = np.max(rt_pred)
    min_rt = np.min(rt_pred)
    # Small epsilon (as in build_ref_image_from_diann): avoids a zero division
    # and keeps rt == max_rt inside the last row.
    total_rt = max_rt - min_rt + 1e-3
    for rt, mz in data:
        # (rt - min_rt) is normalised as a whole onto [0, max_cycle).
        row = int(np.floor((rt - min_rt) / total_rt * max_cycle))
        col = int(np.floor((mz - ms1_start_mz) / total_ms1_mz * n_bin_ms1))
        if 0 <= row < max_cycle and 0 <= col < n_bin_ms1:
            im[row, col] = 1

    return im


def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, min_rt=None, max_rt=None):
    """Rasterise the precursors of a DIA-NN library into a binary image.

    The library is loaded with ``load_lib`` and reduced to its unique
    (Stripped.Sequence, Precursor.Charge, RT, Precursor.Mz) rows. Each row
    sets one pixel to 1:
      - the row index is the RT scaled linearly onto ``max_cycle`` rows;
      - the column index is the precursor m/z placed into equal-width bins
        spanning ``[ms1_start_mz, ms1_end_mz)``.
    Precursors that fall outside either axis are skipped.

    Args:
        path_parqet: library file passed to ``load_lib``.
        ms1_end_mz, ms1_start_mz: MS1 m/z window.
        bin_mz: nominal width of one m/z bin; the image has
            ``int((ms1_end_mz - ms1_start_mz) // bin_mz)`` columns.
        max_cycle: number of image rows (RT cycles).
        min_rt, max_rt: RT range mapped onto the rows. They default to the
            library's own RT extremes; pass a shared range to make images from
            different libraries comparable.

    Returns:
        float ndarray of shape ``(max_cycle, n_bins)`` holding 0s and 1s.
    """
    df = load_lib(path_parqet)
    df = df[['Stripped.Sequence', 'Precursor.Charge', 'RT', 'Precursor.Mz']]
    df_unique = df.drop_duplicates()

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = int(total_ms1_mz // bin_mz)
    im = np.zeros([max_cycle, n_bin_ms1])

    if max_rt is None:
        max_rt = np.max(df_unique['RT'])
    if min_rt is None:
        min_rt = np.min(df_unique['RT'])
    # Small epsilon: avoids a zero division and keeps rt == max_rt in the last row.
    total_rt = max_rt - min_rt + 1e-3

    rt = df_unique['RT'].to_numpy(dtype=float)
    mz = df_unique['Precursor.Mz'].to_numpy(dtype=float)
    # floor (not int() truncation) so values just below the axis start get a
    # negative index and are dropped, instead of landing in row/bin 0.
    rows = np.floor((rt - min_rt) / total_rt * max_cycle)
    cols = np.floor((mz - ms1_start_mz) / total_ms1_mz * n_bin_ms1)
    # Bound both axes by the real image size. A negative index would otherwise
    # wrap around silently, and an RT outside a caller-supplied range would
    # overflow. NaNs compare False and drop out.
    inside = (rows >= 0) & (rows < max_cycle) & (cols >= 0) & (cols < n_bin_ms1)
    im[rows[inside].astype(int), cols[inside].astype(int)] = 1

    return im



if __name__ == '__main__':
    # Other entry points, kept for reference:
    # fasta_similarity('fasta/uniparc_proteome_UP000033376_2025_03_14.fasta',
    #                  'fasta/uniparc_proteome_UP000033499_2025_03_14.fasta')
    # im = build_ref_image_from_diann(
    #     'fasta/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet',
    #     ms1_end_mz=1250, ms1_start_mz=350, bin_mz=1, max_cycle=663)
    # mpimg.imsave('test_img.png', im)

    reference_peptides = build_database_ref_peptide()  # result not used below yet

    # Take the RT span of the full proteome library so that every species
    # image below shares the same RT axis.
    full_library = load_lib(
        'fasta/full proteom/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
    rt_bounds = {'min_rt': full_library['RT'].min(), 'max_rt': full_library['RT'].max()}

    species_list = (
        'Proteus mirabilis',
        'Klebsiella pneumoniae',
        'Klebsiella oxytoca',
        'Enterobacter hormaechei',
        'Citrobacter freundii',
    )
    for species in species_list:
        image = build_ref_image_from_diann(
            f'fasta/optimal peptide set/{species}.parquet',
            ms1_end_mz=1250, ms1_start_mz=350, bin_mz=1, max_cycle=663,
            **rt_bounds,
        )
        plt.clf()
        mpimg.imsave(f'{species}.png', image)