import pyopenms as oms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from PIL import Image

def plot_spectra_2d(exp, ms_level=1, marker_size=5):
    """Show a retention-time vs m/z scatter plot of all spectra of one MS level.

    Every peak is drawn at (RT of its spectrum, peak m/z) and coloured by its
    intensity on a logarithmic ``afmhot_r`` scale ranging from the
    experiment's minimum intensity + 1 to its maximum intensity.  Within a
    spectrum the peaks are drawn in ascending intensity order, so the most
    intense ones end up on top.

    Parameters
    ----------
    exp : pyopenms.MSExperiment
        Experiment to plot; ``updateRanges()`` is called on it before the
        intensity range is read.
    ms_level : int, optional
        Only spectra with this MS level are drawn (default 1).
    marker_size : float, optional
        Marker size passed to ``plt.scatter`` as ``s`` (default 5).

    Notes
    -----
    Ends with ``plt.show()``, which is slow for larger data sets.
    """
    exp.updateRanges()
    vmin = exp.getMinIntensity() + 1
    vmax = exp.getMaxIntensity()

    for spectrum in exp:
        if spectrum.getMSLevel() != ms_level:
            continue
        mz, intensity = spectrum.get_peaks()
        # Ascending intensity: the strongest peaks are drawn last (on top).
        order = intensity.argsort()
        retention = np.full([mz.shape[0]], spectrum.getRT(), float)
        plt.scatter(
            retention,
            mz[order],
            c=intensity[order],
            cmap="afmhot_r",
            s=marker_size,
            norm=colors.LogNorm(vmin, vmax),
        )

    plt.clim(vmin, vmax)
    plt.xlabel("time (s)")
    plt.ylabel("m/z")
    plt.colorbar()
    plt.show()


def count_data_points(exp):
    """Return the total number of peaks over all spectra of ``exp``.

    Parameters
    ----------
    exp : pyopenms.MSExperiment
        Experiment whose spectra (``getSpectra()``) are counted.

    Returns
    -------
    int
        Sum of ``size()`` over every spectrum; 0 for an empty experiment.
    """
    return sum(spectrum.size() for spectrum in exp.getSpectra())


def reconstruct_spectra(exp, ind, chrom_index=1):
    """Return chromatogram intensities recorded between two consecutive spectra.

    Takes chromatogram ``chrom_index`` of ``exp``.  It keeps the intensities
    whose retention time lies in the closed interval
    ``[RT(spectrum ind), RT(spectrum ind + 1)]``.

    Parameters
    ----------
    exp : pyopenms.MSExperiment
        Experiment holding both the spectra and the chromatograms.
    ind : int
        Index of the reference spectrum; spectrum ``ind + 1`` must exist.
    chrom_index : int, optional
        Index into ``exp.getChromatograms()`` (default 1).

    Returns
    -------
    tuple
        ``(data, peaks)``:
        - ``data`` is the array of chromatogram intensities inside the RT window.
        - ``peaks`` is the ``(mz, intensity)`` pair from ``get_peaks()`` of spectrum ``ind``.
    """
    chromatogram = exp.getChromatograms()[chrom_index]
    ref = exp.getSpectrum(ind)
    rt_start = ref.getRT()
    rt_end = exp.getSpectrum(ind + 1).getRT()

    chrom_rt, chrom_intensity = chromatogram.get_peaks()
    in_window = (rt_start <= chrom_rt) & (chrom_rt <= rt_end)
    return chrom_intensity[in_window], ref.get_peaks()

def build_image(e, bin_mz):
    """Bin a cycle-annotated run into one 2-D intensity plane per acquisition cycle.

    Spectra are grouped by the ``cycle`` field of their native ID.  The native
    ID must be a whitespace-separated list of ``key=value`` tokens; MS2
    spectra also need an ``experiment`` field.  Each plane has shape
    ``(100, n_bin_ms2 + 1)``:

    * Column 0 holds the cycle's MS1 spectrum, summed into 100 equal-width
      m/z bins spanning the MS1 scan window (row = MS1 m/z bin).
    * Columns ``1 .. n_bin_ms2`` hold the cycle's MS2 spectra, one row per
      experiment (row = ``experiment - 2``).  Peaks are summed into
      ``n_bin_ms2`` equal-width bins, each at most ``bin_mz`` wide, spanning
      the MS2 scan window.

    Scan windows are read from the first MS1 and the first MS2 spectrum.
    Peaks outside a scan window are ignored; a peak on the upper edge is
    counted in the last bin.  If a cycle has several MS1 spectra, the last
    one is used.  Without any MS1 spectrum, column 0 stays zero.

    Parameters
    ----------
    e : pyopenms.MSExperiment
        Run to bin; ``updateRanges()`` is called on it.
    bin_mz : float
        Maximum width of an MS2 m/z bin.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(max_cycle, 100, n_bin_ms2 + 1)``.  Plane ``c`` holds
        cycle ``c + 1``, and ``max_cycle`` is the largest cycle number seen.

    Raises
    ------
    ValueError
        If ``e`` contains no MS2 spectrum or a cycle number is below 1.
    IndexError
        If an MS2 spectrum's ``experiment - 2`` falls outside the 100 rows.
    """
    n_bin_ms1 = 100  # fixed number of MS1 bins / image rows for now

    def bin_peaks(spec, start, end, size, n_bins):
        # Bin indices (0 .. n_bins - 1) and intensities of the peaks of
        # ``spec`` that lie inside [start, end].
        mz, intensity = spec.get_peaks()
        mz = np.asarray(mz, dtype=float)
        inside = (mz >= start) & (mz <= end)
        idx = np.floor_divide(mz[inside] - start, size).astype(np.intp)
        # A peak on the upper edge (or rounded onto it) would index bin n_bins.
        return np.minimum(idx, n_bins - 1), np.asarray(intensity)[inside]

    e.updateRanges()

    ms1_window = None
    ms2_window = None
    ms1_by_cycle = {}  # cycle -> MS1 spectrum
    ms2_by_cycle = {}  # cycle -> [(row, MS2 spectrum), ...]
    max_cycle = 0
    for spec in e:
        fields = dict(token.split('=', 1) for token in spec.getNativeID().split())
        cycle = int(fields['cycle'])
        if cycle < 1:
            raise ValueError(f"cycle numbers must start at 1, got {cycle}")
        max_cycle = max(max_cycle, cycle)
        level = spec.getMSLevel()
        if level == 1:
            if ms1_window is None:
                ms1_window = spec.getInstrumentSettings().getScanWindows()[0]
            ms1_by_cycle[cycle] = spec
        elif level == 2:
            if ms2_window is None:
                ms2_window = spec.getInstrumentSettings().getScanWindows()[0]
            row = int(fields['experiment']) - 2  # rows start at experiment 2
            if not 0 <= row < n_bin_ms1:
                raise IndexError(
                    f"experiment {row + 2} does not fit in {n_bin_ms1} image rows"
                )
            ms2_by_cycle.setdefault(cycle, []).append((row, spec))

    if ms2_window is None:
        raise ValueError("experiment contains no MS2 spectrum")

    ms2_start_mz, ms2_end_mz = ms2_window.begin, ms2_window.end
    total_ms2_mz = ms2_end_mz - ms2_start_mz
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    im = np.zeros([max_cycle, n_bin_ms1, n_bin_ms2 + 1])

    if ms1_window is not None:
        ms1_start_mz, ms1_end_mz = ms1_window.begin, ms1_window.end
        size_bin_ms1 = (ms1_end_mz - ms1_start_mz) / n_bin_ms1
        for cycle, ms1 in ms1_by_cycle.items():
            idx, values = bin_peaks(ms1, ms1_start_mz, ms1_end_mz, size_bin_ms1, n_bin_ms1)
            # np.add.at accumulates repeated indices (fancy-index += would not).
            np.add.at(im[cycle - 1, :, 0], idx, values)

    for cycle, entries in ms2_by_cycle.items():
        plane = im[cycle - 1]
        for row, ms2 in entries:
            idx, values = bin_peaks(ms2, ms2_start_mz, ms2_end_mz, size_bin_ms2, n_bin_ms2)
            # Column 0 holds the MS1 profile, so MS2 bin b goes to column b + 1.
            np.add.at(plane[row], idx + 1, values)

    return im

def build_image_generic(e, bin_mz, num_RT):
    """Build an MS2 intensity image, delimiting cycles by MS1 spectra.

    Unlike the builders that parse native IDs, this groups spectra purely by
    acquisition order.  Each MS1 spectrum opens a new cycle, and the MS2
    spectra that follow it are appended to that cycle.  Within a processed
    cycle, the k-th MS2 spectrum fills row k of that cycle's plane.  Its peaks
    are summed into ``n_bin_ms2`` m/z bins spanning
    ``[ms2_start_mz, ms2_end_mz + 10]``.

    Parameters
    ----------
    e : pyopenms.MSExperiment
        Run to bin; ``updateRanges()`` is called on it.  It must contain at
        least one MS1 and one MS2 spectrum and at least two cycles.
    bin_mz : float
        Maximum width of an MS2 m/z bin.
    num_RT : int
        Stride over cycles: only cycles ``0, num_RT, 2 * num_RT, ...`` are
        filled, and the other planes of the result stay zero.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_cycles, experiment_max, n_bin_ms2 + 1)``.
        ``experiment_max`` is the length of the second-to-last cycle minus
        one — presumably its MS2 count, assuming the cycle starts with an MS1
        spectrum (TODO confirm).

    NOTE(review): known problems, left unchanged here:

    * For ``num_RT > 1`` the inner loop reads ``list_cycle[c][k + n]``, which
      runs past the end of cycle ``c``.  ``experiment`` also grows beyond
      ``experiment_max``.  It therefore raises IndexError for any cycle that
      holds MS2 spectra.  ``list_cycle[c + n][k]`` was presumably intended.
    * MS2 spectra that precede the first MS1 spectrum are caught by the bare
      ``except`` and put in bucket 0.  The first MS1 spectrum is then
      inserted into that same bucket.  An extra empty bucket is left at the
      end of ``list_cycle``, which yields an all-zero final plane.
    * A cycle with more MS2 spectra than ``experiment_max`` can raise
      IndexError.
    * Peaks below ``ms2_start_mz`` give a negative column index, which
      silently wraps to the end of the row.
    * ``total_by_window`` and ``size_bin_ms1`` are computed but never used.
    """
    e.updateRanges()
    list_cycle = []  # one list of spectra per cycle, its MS1 spectrum first

    # MS2 scan window, read from the first MS2 spectrum.
    for s in e:
        if s.getMSLevel() == 2:
            ms2_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms2_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    # MS1 scan window, read from the first MS1 spectrum.  If there is no MS1
    # spectrum these names stay unbound and the MS1 range computation below
    # raises UnboundLocalError.
    for s in e:
        if s.getMSLevel() == 1:
            ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms1_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    # MS2 bins cover the scan window padded by 10 m/z; each is at most bin_mz wide.
    total_ms2_mz = ms2_end_mz - ms2_start_mz + 10
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    # MS1 binning values (NOTE(review): size_bin_ms1 is never used below).
    total_ms1_mz = ms1_end_mz - ms1_start_mz + 10
    n_bin_ms1 = 100  # for now
    size_bin_ms1 = total_ms1_mz / n_bin_ms1

    # Group spectra into cycles by acquisition order.
    cycle = -1
    for spec in e:  # build list_cycle
        if spec.getMSLevel() == 1:
            cycle += 1
            list_cycle.append([])
            # NOTE(review): if MS2 spectra preceded the first MS1 spectrum,
            # list_cycle[cycle] is not the bucket just appended (see docstring).
            list_cycle[cycle].insert(0, spec)
        if spec.getMSLevel() == 2:
            try :
                list_cycle[cycle].append(spec)
            except :
                # Only reached while list_cycle is still empty, i.e. for MS2
                # spectra before the first MS1 spectrum.
                # NOTE(review): bare except; IndexError is the only expected error.
                list_cycle.append([])
                list_cycle[cycle].append(spec)
    max_cycle = len(list_cycle)
    total_by_window = max_cycle//num_RT + 1  # NOTE(review): unused
    # Rows per plane, taken from the second-to-last cycle (presumably because
    # the last one may be incomplete — TODO confirm).  Raises IndexError when
    # there are fewer than two cycles.
    experiment_max = len(list_cycle[-2])-1
    im = np.zeros([max_cycle, experiment_max, n_bin_ms2 + 1])
    for c in range(0, max_cycle, num_RT):  # fill one plane every num_RT cycles
        j=0  # index of the first MS2 spectrum in the cycle
        experiment = 0  # row of chan being filled
        chan = np.zeros([experiment_max, n_bin_ms2 + 1])
        if len(list_cycle[c])>0 :
            if list_cycle[c][0].getMSLevel() == 1:
                j=1  # skip the leading MS1 spectrum
                pass

        for k in range(j, len(list_cycle[c])):
            for n in range(num_RT):
                # NOTE(review): k + n overruns the cycle when num_RT > 1 (see docstring).
                ms2 = list_cycle[c][k+n]
                intensity = ms2.get_peaks()[1]
                mz = ms2.get_peaks()[0]

                # Sum each peak into its m/z bin; peaks below ms2_start_mz
                # produce a negative index that wraps around.
                for i in range(ms2.size()):
                    chan[experiment, int((mz[i]-ms2_start_mz) // size_bin_ms2)] += intensity[i]
                experiment +=1
        im[c, :, :] = chan

    return im


def build_image_frag(e, bin_mz):
    """Bin the MS2 spectra of a cycle-annotated run into one plane per cycle.

    Spectra are grouped by the ``cycle`` field of their native ID.  The native
    ID must be a whitespace-separated list of ``key=value`` tokens; MS2
    spectra also need an ``experiment`` field.  Each plane has shape
    ``(100, n_bin_ms2)``.  Row ``experiment - 2`` holds that MS2 spectrum
    summed into ``n_bin_ms2`` equal-width bins, each at most ``bin_mz`` wide,
    spanning the MS2 scan window of the first MS2 spectrum.

    Spectra of other MS levels only contribute their cycle number.  Peaks
    outside the scan window are ignored; a peak on the upper edge is counted
    in the last bin.

    Parameters
    ----------
    e : pyopenms.MSExperiment
        Run to bin; ``updateRanges()`` is called on it.
    bin_mz : float
        Maximum width of an m/z bin.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(max_cycle, 100, n_bin_ms2)``.  Plane ``c`` holds
        cycle ``c + 1``, and ``max_cycle`` is the largest cycle number seen.

    Raises
    ------
    ValueError
        If ``e`` contains no MS2 spectrum or a cycle number is below 1.
    IndexError
        If an MS2 spectrum's ``experiment - 2`` falls outside the 100 rows.
    """
    n_rows = 100  # fixed number of image rows for now

    e.updateRanges()

    ms2_window = None
    ms2_by_cycle = {}  # cycle -> [(row, MS2 spectrum), ...]
    max_cycle = 0
    for spec in e:
        fields = dict(token.split('=', 1) for token in spec.getNativeID().split())
        cycle = int(fields['cycle'])
        if cycle < 1:
            raise ValueError(f"cycle numbers must start at 1, got {cycle}")
        max_cycle = max(max_cycle, cycle)
        if spec.getMSLevel() != 2:
            continue  # only MS2 spectra are imaged
        if ms2_window is None:
            ms2_window = spec.getInstrumentSettings().getScanWindows()[0]
        row = int(fields['experiment']) - 2  # rows start at experiment 2
        if not 0 <= row < n_rows:
            raise IndexError(f"experiment {row + 2} does not fit in {n_rows} image rows")
        ms2_by_cycle.setdefault(cycle, []).append((row, spec))

    if ms2_window is None:
        raise ValueError("experiment contains no MS2 spectrum")

    ms2_start_mz, ms2_end_mz = ms2_window.begin, ms2_window.end
    total_ms2_mz = ms2_end_mz - ms2_start_mz
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    im = np.zeros([max_cycle, n_rows, n_bin_ms2])
    for cycle, entries in ms2_by_cycle.items():
        plane = im[cycle - 1]
        for row, ms2 in entries:
            mz, intensity = ms2.get_peaks()
            mz = np.asarray(mz, dtype=float)
            inside = (mz >= ms2_start_mz) & (mz <= ms2_end_mz)
            idx = np.floor_divide(mz[inside] - ms2_start_mz, size_bin_ms2).astype(np.intp)
            # A peak on the upper edge (or rounded onto it) would index column n_bin_ms2.
            idx = np.minimum(idx, n_bin_ms2 - 1)
            # np.add.at accumulates repeated indices (fancy-index += would not).
            np.add.at(plane[row], idx, np.asarray(intensity)[inside])

    return im


def check_windows(e):
    """Collect the isolation-window bounds of the MS2 spectra of every cycle.

    Spectra are grouped by the ``cycle`` field of their native ID, which must
    be a whitespace-separated list of ``key=value`` tokens.  For each MS2
    spectrum the first precursor's window is reported as
    ``mz - lower_offset`` and ``mz + upper_offset``.

    Parameters
    ----------
    e : pyopenms.MSExperiment
        Run to inspect; ``updateRanges()`` is called on it.

    Returns
    -------
    list of list of float
        One list per cycle ``1 .. max_cycle``.  Each list holds
        ``[low_1, high_1, low_2, high_2, ...]`` for the cycle's MS2 spectra in
        acquisition order.  Cycles without MS2 spectra give an empty list, and
        an empty run gives ``[]``.

    Raises
    ------
    ValueError
        If a cycle number is below 1.
    """
    e.updateRanges()

    bounds_by_cycle = {}
    max_cycle = 0
    for spec in e:
        fields = dict(token.split('=', 1) for token in spec.getNativeID().split())
        cycle = int(fields['cycle'])
        if cycle < 1:
            raise ValueError(f"cycle numbers must start at 1, got {cycle}")
        max_cycle = max(max_cycle, cycle)
        if spec.getMSLevel() != 2:
            continue
        precursor = spec.getPrecursors()[0]
        centre = precursor.getMZ()
        bounds = bounds_by_cycle.setdefault(cycle, [])
        bounds.append(centre - precursor.getIsolationWindowLowerOffset())
        bounds.append(centre + precursor.getIsolationWindowUpperOffset())

    return [bounds_by_cycle.get(c, []) for c in range(1, max_cycle + 1)]

def check_energy(im):
    """Return the fragment/precursor intensity ratio of every (RT, window) cell.

    ``im`` is indexed as ``im[rt, window, channel]``.  Channel 0 is taken as
    the precursor intensity, and the sum of channels ``1:`` as the fragment
    intensity.

    Parameters
    ----------
    im : numpy.ndarray
        3-D array of shape ``(n_rt, n_window, n_channel)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_rt, n_window)`` holding
        ``sum(im[i, f, 1:]) / im[i, f, 0]``, or 0 where ``im[i, f, 0] == 0``.
    """
    n_rt, n_window = im.shape[0], im.shape[1]
    ratios = np.zeros((n_rt, n_window))
    if n_rt == 0 or n_window == 0:
        # No cells to fill; return before indexing a possibly empty channel axis.
        return ratios

    precursor = im[:, :, 0]
    fragments = im[:, :, 1:].sum(axis=2)
    # Divide only where there is precursor signal; other cells keep their 0.
    np.divide(fragments, precursor, out=ratios, where=precursor != 0)
    return ratios

if __name__ == "__main__":
    # Convert one run into a log-scaled fragment image and store it as .npy.
    mzml_path = "data/STAPH140.mzML"
    npy_path = 'data/mz_image/Staph140.npy'

    e = oms.MSExperiment()
    oms.MzMLFile().load(mzml_path, e)

    # MS2 image with m/z bins at most 2 wide.
    im = build_image_frag(e, 2)

    # log(x + 1) compresses the dynamic range; np.maximum clamps any negative
    # result (intensities between -1 and 0) to 0.
    im2 = np.maximum(0, np.log(im + 1))
    np.save(npy_path, im2)

    # Disabled exploration code, kept for reference:
    # norm = np.max(im2)
    # for i in range(im.shape[0]) :
    #     mat = im2[i, :, :]
    #     img = Image.fromarray(mat / norm)
    #     img.save('fig/mzimage/RT_frag_'+str(i)+'.tif')
    # res = check_windows(e)
    #
    # max_len = np.array([len(array) for array in res]).max()
    #
    # # What value do we want to fill it with?
    # default_value = 0
    #
    # b = [np.pad(array, (0, max_len - len(array)), mode='constant', constant_values=default_value) for array in res]