import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from loess.loess_1d import loess_1d
import time

# Numeric encoding of amino-acid residues: "" (padding) is 0, the twenty
# standard residues are 1..20 in alphabetical order, followed by the two
# modified residues OxM (oxidized Met) and CaC (carbamidomethyl Cys).
_RESIDUES = "ACDEFGHIKLMNPQRSTVWY"

ALPHABET_UNMOD = {"": 0}
ALPHABET_UNMOD.update({aa: i + 1 for i, aa in enumerate(_RESIDUES)})
ALPHABET_UNMOD.update({"OxM": 21, "CaC": 22})

# Inverse lookup: numeric code -> residue string.
ALPHABET_UNMOD_REV = {v: k for k, v in ALPHABET_UNMOD.items()}


def align(dataset, reference, column_dataset, column_ref, seq_data, seq_ref):
    """Rescale `dataset[column_dataset]` onto the reference retention-time
    scale via a LOESS fit over peptides present in both frames.

    Parameters
    ----------
    dataset : DataFrame with a 'state' column; only rows with
        state == 'train' are used to fit the alignment, but ALL rows get
        their `column_dataset` values replaced by the fitted curve.
    reference : DataFrame providing the target retention-time scale.
    column_dataset, column_ref : retention-time column names in `dataset`
        and `reference` respectively.
    seq_data, seq_ref : sequence column names used to match peptides
        between the two frames.

    Side effects: writes diagnostic scatter plots to 'test.png' and
    'test_2.png', and mutates `dataset` in place (also returned).
    """
    dataset_ref=dataset[dataset['state']=='train']
    # Mean retention time per unique sequence in each frame.
    dataset_unique = dataset_ref[[seq_data,column_dataset]].groupby(seq_data).mean()
    print('unique',len(dataset_unique))
    reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
    # NOTE(review): the parameter `seq_ref` is shadowed from here on — it is
    # reused to hold the reference index values instead of the column name.
    seq_ref = reference_unique.index
    seq_common = dataset_unique.index
    seq_ref = seq_ref.tolist()
    seq_common = seq_common.tolist()

    # Converted to tuples so list-valued sequence keys become hashable
    # (for string keys this just splits them into character tuples).
    seq_ref = [tuple(l) for l in seq_ref]
    seq_common = [tuple(l) for l in seq_common]

    # Map each sequence to its positional index, then keep only positions of
    # sequences found in both frames. Both position lists are built by
    # iterating the SAME set `inter`, so they remain pairwise aligned.
    ind_dict_ref = dict((k, i) for i, k in enumerate(seq_ref))
    inter = set(ind_dict_ref).intersection(seq_common)
    print(len(inter))

    ind_dict_ref = [ind_dict_ref[x] for x in inter]

    indices_common = dict((k, i) for i, k in enumerate(seq_common))
    indices_common = [indices_common[x] for x in inter]


    # NOTE(review): indexing a Series with a list of ints falls back to
    # positional lookup here (the index holds sequences, not ints) — this
    # behavior is deprecated in newer pandas; `.iloc` would be the explicit
    # spelling. Confirm the pinned pandas version.
    rt_ref = reference_unique[column_ref][ind_dict_ref].reset_index()
    rt_data = dataset_unique[column_dataset][indices_common].reset_index()

    plt.scatter(rt_data[column_dataset].tolist(),rt_ref[column_ref].tolist(),s=0.1)
    plt.savefig('test.png')

    # Fit reference RT as a smooth function of dataset RT on the common
    # peptides, then evaluate the curve at every row of `dataset`.
    xout, yout, wout = loess_1d(np.array(rt_data[column_dataset].tolist()), np.array(rt_ref[column_ref].tolist()),
                                xnew=dataset[column_dataset],
                                degree=1, frac=0.25,
                                npoints=None, rotate=False, sigy=None)

    plt.scatter(xout, yout, s=0.1)
    plt.savefig('test_2.png')

    # Overwrite the dataset's retention times with the aligned values.
    dataset[column_dataset] = yout
    return dataset

def get_number_unique_peptide(dataset):
    """Return the number of distinct values in the 'sequence' column.

    `dropna=False` preserves the original `len(unique())` semantics, which
    counted NaN as a distinct value.
    """
    return dataset['sequence'].nunique(dropna=False)

def compare_error(df1, df2, display=False, save=False, path=None):
    """Scatter-plot the per-sequence mean absolute RT error of two
    prediction runs against each other.

    Both frames need 'seq', 'rt pred' and 'true rt' columns. Adds an
    'abs err 1' / 'abs err 2' column to df1 / df2 in place. Always writes
    the figure to 'temp.png'; optionally also saves it to `path`
    (save=True) and/or shows it interactively (display=True).
    """
    # BUG FIX: the columns are named "abs err" and select_best_data uses the
    # absolute error, but this function originally stored the signed residual.
    df1['abs err 1'] = (df1['rt pred'] - df1['true rt']).abs()
    df2['abs err 2'] = (df2['rt pred'] - df2['true rt']).abs()
    df_group_1 = df1.groupby(['seq'])['abs err 1'].mean().to_frame().reset_index()
    df_group_2 = df2.groupby(['seq'])['abs err 2'].mean().to_frame().reset_index()
    df = pd.concat([df_group_1, df_group_2], axis=1)

    fig, ax = plt.subplots()
    ax.scatter(df['abs err 1'], df['abs err 2'], s=0.1, alpha=0.05)

    plt.savefig('temp.png')

    # BUG FIX: save before show — plt.show() blocks and may leave the current
    # figure closed, so saving afterwards could produce a blank image.
    if save:
        plt.savefig(path)

    if display:
        plt.show()

def select_best_data(df_list,threshold):
    """From several prediction runs over the same peptides, keep the
    sequences whose mean absolute RT error across runs is below
    `threshold`, and return them as a two-column frame
    ['sequence', 'irt_scaled'] taken from the first run's true RTs.

    Every frame in `df_list` must carry 'seq', 'rt pred' and 'true rt'
    columns; each gets an 'abs err i' column added in place.
    """
    num = len(df_list)
    l=[]
    i=0
    # Per-run mean absolute error per sequence.
    for df in df_list :
        df['abs err {}'.format(i)] = abs(df['rt pred'] - df['true rt'])
        df_group = df.groupby(['seq'])['abs err {}'.format(i)].mean().to_frame().reset_index()
        l.append(df_group)
        i += 1
    # Side-by-side concat: the result carries `num` duplicate 'seq' columns
    # (one per run) next to the per-run error columns.
    # NOTE(review): this assumes every run's groupby produced the same
    # sequences in the same order so rows line up — confirm upstream.
    df = pd.concat(l, axis=1)
    # Average the per-run errors.
    df['mean'] = df['abs err 0']
    for i in range(1,num):
        df['mean']=df['mean']+df['abs err {}'.format(i)]
    df['mean'] = df['mean']/num
    df_res = df[df['mean']<threshold]
    # Selecting ['seq', 'mean'] on a frame with duplicated 'seq' labels
    # yields ALL `num` 'seq' columns plus 'mean'; the rename below gives
    # them unique names so a single 'seq' column can then be selected.
    c_name=['seq']+['seq{}'.format(i) for i in range(1,num)]+['mean']
    df_res = df_res[['seq','mean']]
    df_res.columns = c_name
    df_res = df_res[['seq', 'mean']]
    # Keep only surviving sequences, paired with the first run's true RT.
    df_merged = df_list[0].merge(df_res, how='inner', on='seq')
    df_merged = df_merged[['seq','true rt']]
    df_merged.columns = ['sequence','irt_scaled']
    return df_merged

def add_length(dataframe):
    """Add a 'length' column: the count of non-zero codes in each
    numerically-encoded sequence string, e.g. "[1,2,0]" -> 2.

    Zero is the padding code, so this is the unpadded sequence length.
    Mutates `dataframe` in place.
    """
    def _nonpad_count(encoded):
        tokens = encoded.replace('[', '').replace(']', '').split(',')
        return sum(1 for token in tokens if int(token) != 0)

    dataframe['length'] = dataframe['seq'].map(_nonpad_count)

def add_length_alpha(dataframe):
    """Add a 'length' column holding the character length of each plain
    (alphabetical) sequence string. Mutates `dataframe` in place."""
    dataframe['length'] = dataframe['sequence'].str.len()

def numerical_to_alphabetical_str(s):
    """Decode a stringified integer sequence like "[1, 2, 0]" into its
    amino-acid string via ALPHABET_UNMOD_REV (code 0 decodes to "", so
    padding simply disappears from the result).
    """
    codes = s.replace('[', '').replace(']', '').split(',')
    # ''.join avoids the quadratic cost of repeated `+=` concatenation
    # that the original index loop incurred.
    return ''.join(ALPHABET_UNMOD_REV[int(code)] for code in codes)

# def main():
#     ref = pd.read_csv('data_prosit/data.csv')
#     df_ISA = pd.read_csv('data_PXD006109/e_coli/data_coli.csv')
#     df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence')
#     df_ISA_aligned.to_csv('data_PXD006109/e_coli/data_aligned_train_coli.csv', index=False)




if __name__ == '__main__':
    # main()

    # NOTE(review): the commented block below is an earlier experiment that
    # built augmented training sets at several error thresholds from five
    # prediction runs; kept for reference.
    # df_base = pd.read_csv('./data_PXD006109/plasma_train/data_aligned_train_plasma.csv')
    # df_base = df_base[['sequence', 'irt_scaled','state']]
    # t = [0.05,0.1,0.2,0.3,0.4,0.5,0.7,1,10]
    # name = ['005','01','02','03','04','05','07','1','all']
    # df_0 = pd.read_csv('../output/out_plasma_aligned_train_0.csv')
    # df_1 = pd.read_csv('../output/out_plasma_aligned_train_1.csv')
    # df_2 = pd.read_csv('../output/out_plasma_aligned_train_2.csv')
    # df_3 = pd.read_csv('../output/out_plasma_aligned_train_3.csv')
    # df_4 = pd.read_csv('../output/out_plasma_aligned_train_4.csv')
    #
    # list_df = [df_0, df_1, df_2, df_3, df_4]
    # for i in range(len(name)):
    #     #creating augmented datasets
    #     print('thresold {} en cours'.format(name[i]))
    #     #
    #     df = select_best_data(list_df, t[i])
    #     df.to_pickle('./data_PXD006109/plasma_train/data_ISA_additionnal_{}.pkl'.format(name[i]))
    #     df = pd.read_pickle('./data_PXD006109/plasma_train/data_ISA_additionnal_{}.pkl'.format(name[i]))
    #     df['state'] = 'train'
    #     df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
    #     df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
    #     df_augmented_1.columns = ['sequence', 'irt_scaled','state']
    #
    #     df_augmented_1.to_csv('./data_PXD006109/plasma_train/plasma_train_data_augmented_{}.csv'.format(name[i]), index=False)
    #     print(df_augmented_1.shape)

    # Current task: keep only Prosit peptides of length >= 10 and write
    # them to a separate CSV.
    df = pd.read_csv('data_prosit/data_noc.csv')
    add_length_alpha(df)
    df_long = df[df['length']>=10]
    df_long.to_csv('data_prosit/data_noc_long.csv')
    #test