import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pyopenms as oms

def plot_spectra_2d(exp, ms_level=1, marker_size=5, out_path='temp.png'):
    """Save a retention-time vs m/z scatter plot of the spectra at one MS level.

    Every peak of every spectrum whose MS level equals *ms_level* is drawn as a
    point at (spectrum RT, peak m/z), coloured by intensity on a logarithmic
    ``afmhot_r`` scale spanning the experiment's intensity range.

    All points are drawn in a single scatter call, ordered by ascending
    intensity so the most intense peaks end up on top across the whole plot.
    The plot is made on a fresh figure that is closed after saving, so repeated
    calls neither overlay each other nor accumulate open figures.

    Args:
        exp: pyopenms ``MSExperiment``; ``updateRanges()`` is called on it.
        ms_level: only spectra with this MS level are plotted.
        marker_size: scatter marker size (matplotlib ``s`` argument).
        out_path: file the figure is written to.
    """
    exp.updateRanges()
    # +1 keeps vmin positive (required by the log scale) even when the minimum
    # intensity is 0; vmax is clamped because LogNorm rejects vmin > vmax.
    vmin = exp.getMinIntensity() + 1
    vmax = max(exp.getMaxIntensity(), vmin)
    norm = colors.LogNorm(vmin, vmax)

    rt_parts, mz_parts, intensity_parts = [], [], []
    for spec in exp:
        if spec.getMSLevel() != ms_level:
            continue
        mz, intensity = spec.get_peaks()
        rt_parts.append(np.full(mz.shape[0], spec.getRT(), float))
        mz_parts.append(mz)
        intensity_parts.append(intensity)

    # An empty selection still yields a valid (empty) plot with a colorbar.
    rt, mz, intensity = (
        np.concatenate(parts) if parts else np.empty(0)
        for parts in (rt_parts, mz_parts, intensity_parts)
    )
    order = intensity.argsort()  # plot highest intensities last, i.e. on top

    fig, ax = plt.subplots()
    try:
        points = ax.scatter(
            rt[order],
            mz[order],
            c=intensity[order],
            cmap="afmhot_r",
            s=marker_size,
            norm=norm,
        )
        ax.set_xlabel("time (s)")
        ax.set_ylabel("m/z")
        fig.colorbar(points, ax=ax)
        fig.savefig(out_path)  # still the slowest step for large data sets
    finally:
        plt.close(fig)


def _parse_native_id(native_id):
    """Return the whitespace-separated ``key=value`` tokens of *native_id* as a dict.

    Values are kept as strings; a token without ``=`` raises ``ValueError``.
    """
    return dict(token.split('=', 1) for token in native_id.split())


def build_image_ms1(path, bin_mz):
    """Load an mzML file and sum its MS1 peaks into a (cycle, m/z bin) image.

    Spectrum native IDs must contain a ``cycle=<int>`` token, treated as
    1-based: row ``c`` of the result holds cycle ``c + 1``, and all MS1 spectra
    of the same cycle are summed into that row.

    The m/z axis covers the first scan window of the first MS1 spectrum, which
    is assumed to apply to every MS1 spectrum (NOTE(review): confirm for the
    acquisition method in use). The window is split into
    ``int(window_width // bin_mz)`` equal bins, so the effective bin width is
    ``window_width / n_bins`` (slightly >= *bin_mz*). A peak exactly at the
    window end falls into the last bin; peaks outside the window are ignored.

    Args:
        path: path to the mzML file.
        bin_mz: requested m/z bin width (must be positive).

    Returns:
        ``numpy.ndarray`` of shape ``(n_cycles, n_bins)`` with summed
        intensities, where ``n_cycles`` is the largest cycle number found on
        the last spectrum or on any MS1 spectrum.

    Raises:
        ValueError: if the file contains no MS1 spectra, or *bin_mz* is wider
            than the MS1 scan window.
    """
    exp = oms.MSExperiment()
    oms.MzMLFile().load(path, exp)
    exp.updateRanges()

    ms1_spectra = [spec for spec in exp if spec.getMSLevel() == 1]
    if not ms1_spectra:
        raise ValueError(f"no MS1 spectra found in {path!r}")

    ms1_cycles = [
        int(_parse_native_id(spec.getNativeID())['cycle']) for spec in ms1_spectra
    ]
    # The last spectrum's cycle sets the row count as before; it is widened
    # if some MS1 spectrum reports a later cycle (previously an IndexError).
    last_cycle = int(_parse_native_id(exp.getSpectra()[-1].getNativeID())['cycle'])
    n_cycles = max(last_cycle, max(ms1_cycles))

    window = ms1_spectra[0].getInstrumentSettings().getScanWindows()[0]
    start_mz, end_mz = window.begin, window.end
    window_width = end_mz - start_mz
    n_bins = int(window_width // bin_mz)
    if n_bins <= 0:
        raise ValueError(
            f"bin_mz={bin_mz} is wider than the MS1 scan window "
            f"[{start_mz}, {end_mz}]"
        )
    bin_width = window_width / n_bins

    image = np.zeros((n_cycles, n_bins))
    for spec, cycle in zip(ms1_spectra, ms1_cycles):
        mz, intensity = spec.get_peaks()
        # Out-of-window peaks used to crash (mz >= end) or wrap around via a
        # negative index (mz < start); they are dropped instead.
        inside = (mz >= start_mz) & (mz <= end_mz)
        bins = ((mz[inside] - start_mz) // bin_width).astype(np.intp)
        np.minimum(bins, n_bins - 1, out=bins)  # mz == end_mz -> last bin
        image[cycle - 1] += np.bincount(
            bins, weights=intensity[inside], minlength=n_bins
        )

    return image
