import fastapy
import matplotlib.pyplot as plt
import numpy as np
from matplotlib_venn import venn2




# Integer code for each of the 20 standard (unmodified) amino acids, assigned
# in alphabetical one-letter order starting at 1; no residue is mapped to 0.
ALPHABET_UNMOD = {
    residue: code
    for code, residue in enumerate("ACDEFGHIKLMNPQRSTVWY", start=1)
}


# Monoisotopic residue masses (Da) of the 20 standard amino acids, i.e. the
# mass of each amino acid minus one H2O, as it appears inside a peptide chain.
MASSES_MONO = {
    "A": 71.03711,
    "C": 103.00919,
    "D": 115.02694,
    "E": 129.04259,
    "F": 147.06841,
    "G": 57.02146,
    "H": 137.05891,
    "I": 113.08406,
    "K": 128.09496,
    "L": 113.08406,
    "M": 131.04049,
    "N": 114.04293,
    "P": 97.05276,
    "Q": 128.05858,
    # 156.1875 (previously listed here) is arginine's AVERAGE residue mass;
    # the monoisotopic value is 156.10111.
    "R": 156.10111,
    "S": 87.03203,
    "T": 101.04768,
    "V": 99.06841,
    "W": 186.07931,
    "Y": 163.06333,
}


# Average residue masses (Da) of the 20 standard amino acids, i.e. the mass
# of each amino acid minus one H2O, as it appears inside a peptide chain.
MASSES_AVG = dict(
    A=71.0788,
    C=103.1388,
    D=115.0886,
    E=129.1155,
    F=147.1766,
    G=57.0519,
    H=137.1411,
    I=113.1594,
    K=128.1741,
    L=113.1594,
    M=131.1926,
    N=114.1038,
    P=97.1167,
    Q=128.1307,
    R=156.1875,
    S=87.0782,
    T=101.1051,
    V=99.1326,
    W=186.2132,
    Y=163.1760,
)

# Trypsin cleaves after K or R, unless the next residue is P.

def cut(seq):
    """Return the trypsin cleavage positions in *seq*.

    A position is reported after every K or R that is followed by a residue
    other than P. Each returned value is the index just past the cleaved
    residue, so it can be used directly as a slice boundary. A K/R in the
    last position yields no cut, since nothing would be split off.
    """
    last = len(seq) - 1
    return [
        pos + 1
        for pos in range(last)
        if seq[pos] in ("R", "K") and seq[pos + 1] != "P"
    ]


def cut_with_ind(seq, ind_list):
    """Split *seq* at the given cut positions.

    Args:
        seq: sequence to split (anything sliceable, e.g. a str).
        ind_list: increasing slice boundaries strictly inside the sequence,
            as returned by ``cut``. It is not modified.

    Returns:
        The consecutive pieces ``seq[0:i0], seq[i0:i1], ..., seq[ik:]``.
        With no cut positions the whole sequence is returned as a single
        piece (so uncleavable proteins are kept); an empty sequence yields
        an empty list.
    """
    if not seq:
        return []
    # Build the boundaries in a new list instead of appending len(seq) to the
    # caller's ind_list.
    bounds = [0, *ind_list, len(seq)]
    return [seq[start:end] for start, end in zip(bounds, bounds[1:])]

def digest(seq):
    """Split a protein sequence into tryptic peptides, in sequence order.

    Cleavage positions come from ``cut``; the slicing is done by
    ``cut_with_ind``.
    """
    return cut_with_ind(seq, cut(seq))

def fasta_similarity(path_fasta_1, path_fasta_2, out_path='fasta_similarity.png'):
    """Draw a Venn diagram of the sequences shared by two FASTA files.

    Records are compared by their exact sequence string; headers are ignored
    and duplicate sequences within one file count once.

    Args:
        path_fasta_1: first FASTA file (labelled 'Group1').
        path_fasta_2: second FASTA file (labelled 'Group2').
        out_path: file the figure is saved to; defaults to the previously
            hard-coded 'fasta_similarity.png'.
    """
    set1 = {record.seq for record in fastapy.parse(path_fasta_1)}
    set2 = {record.seq for record in fastapy.parse(path_fasta_2)}

    venn2((set1, set2), ('Group1', 'Group2'))
    # Save before show(). Once the shown window is closed, the current figure
    # may be gone, and a later savefig() writes an empty image.
    plt.savefig(out_path)
    plt.show()

def compute_mass(seq, isotop):
    """Return the summed residue mass of *seq*.

    Args:
        seq: amino-acid string in one-letter codes.
        isotop: 'mono' for monoisotopic masses (MASSES_MONO) or 'avg' for
            average masses (MASSES_AVG).

    Returns:
        The sum of per-residue masses. No terminal H2O is added, so this is
        the residue sum, not the neutral peptide mass. Letters that have no
        entry in the mass table (e.g. 'X', 'U') contribute nothing.

    Raises:
        ValueError: if *isotop* is neither 'mono' nor 'avg' (previously such
            values silently produced a mass of 0).
    """
    if isotop == 'mono':
        table = MASSES_MONO
    elif isotop == 'avg':
        table = MASSES_AVG
    else:
        raise ValueError(f"isotop must be 'mono' or 'avg', got {isotop!r}")
    return sum(mass * seq.count(residue) for residue, mass in table.items())

def build_ref_image(path_fasta, possible_charge, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, rt_pred):
    """Build a binary (cycle x m/z-bin) reference image from an in-silico digest.

    Work in progress: pairing each peptide's m/z values with a predicted
    retention time is not implemented yet. ``data`` therefore stays empty and
    the returned image is all zeros. The rasterisation step is ready for when
    ``data`` gets filled with ``(rt, mz)`` pairs.

    Args:
        path_fasta: FASTA file of the proteins to digest with trypsin.
        possible_charge: iterable of precursor charge states to consider.
        ms1_end_mz: upper bound of the MS1 m/z window (exclusive).
        ms1_start_mz: lower bound of the MS1 m/z window (exclusive).
        bin_mz: width of one m/z bin.
        max_cycle: number of rows (cycles) in the image.
        rt_pred: predicted retention times. Only their min and max are used
            here, to scale RTs onto the cycle axis.

    Returns:
        numpy array of shape ``(max_cycle, int((ms1_end_mz - ms1_start_mz) // bin_mz))``.
    """
    water_avg = 18.01528  # average mass of H2O added once per peptide (termini)
    proton = 1.007276     # mass of one proton, added once per charge

    # Build the peptide list.
    list_peptides = []
    for record in fastapy.parse(path_fasta):
        list_peptides.extend(digest(record.seq))

    # Compute [M + zH]^z+ m/z values inside the MS1 window, keyed by peptide.
    mz_ratio = {}
    for seq in list_peptides:
        neutral_mass = compute_mass(seq, 'avg') + water_avg
        in_window = []
        for charge in possible_charge:
            ratio = (neutral_mass + charge * proton) / charge
            if ms1_end_mz > ratio > ms1_start_mz:
                in_window.append(ratio)
        mz_ratio[seq] = in_window

    # TODO: associate predicted RTs with the m/z values as (rt, mz) pairs.
    # How rt_pred lines up with list_peptides is not defined here.
    data = []

    # TODO: predict detectability (optional).

    # Build the image.
    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = int(total_ms1_mz // bin_mz)
    im = np.zeros([max_cycle, n_bin_ms1])
    max_rt = np.max(rt_pred)
    min_rt = np.min(rt_pred)
    # A zero-width RT range (all predictions equal) would divide by zero;
    # in that case every RT maps to row 0.
    total_rt = (max_rt - min_rt) or 1.0
    for rt, mz in data:
        row = int((rt - min_rt) / total_rt * max_cycle)
        col = int((mz - ms1_start_mz) / total_ms1_mz * n_bin_ms1)
        # rt == max_rt (or mz at the window edge) maps one past the last
        # index, so clamp into the image.
        row = min(max(row, 0), max_cycle - 1)
        col = min(max(col, 0), n_bin_ms1 - 1)
        im[row, col] = 1

    return im

if __name__ == '__main__':
    # Manual entry point: uncomment one of the calls below to run it.
    # fasta_similarity('fasta/uniprotkb_proteome_UP000742934_2025_03_12.fasta','fasta/uniprotkb_proteome_UP001182277_2025_03_12.fasta')
    # NOTE(review): build_ref_image also requires possible_charge, ms1_end_mz,
    # ms1_start_mz, bin_mz, max_cycle and rt_pred; as written, this call would
    # raise TypeError.
    # mass = build_ref_image('fasta/uniprotkb_proteome_UP000742934_2025_03_12.fasta')
    pass