import numpy as np
import pandas as pd
from loess.loess_1d import loess_1d
import scipy as sp
from sklearn.metrics import r2_score
from sympy.abc import alpha

import dataloader
from dataloader import RT_Dataset
from msms_processing import load_data
import matplotlib.pyplot as plt


# Token -> integer code table used to encode peptide sequences. Codes follow
# the position of each token in the list below: "" is code 0 (it decodes to
# nothing, presumably used as padding -- confirm with the encoder), the 20
# single-letter residues take codes 1-20, and the multi-letter tokens "CaC"
# and "OxM" (presumably modified C / M residues -- TODO confirm) take 21-22.
ALPHABET_UNMOD = {
    token: code
    for code, token in enumerate([
        "", "A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M",
        "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "CaC", "OxM",
    ])
}

# Inverse table: integer code -> token, used by the decoders below.
ALPHABET_UNMOD_REV = {code: token for token, code in ALPHABET_UNMOD.items()}

def numerical_to_alphabetical(arr):
    """Decode a sequence of integer codes into a peptide string.

    Each element of ``arr`` is looked up in ``ALPHABET_UNMOD_REV`` and the
    resulting tokens are concatenated in order. Code 0 maps to the empty
    string and therefore contributes nothing to the result.

    Parameters
    ----------
    arr : iterable of int
        Integer codes, e.g. a list, tuple or 1-D array. Elements are consumed
        by iteration (positionally), so a pandas Series with a non-default
        index is handled correctly too.

    Returns
    -------
    str
        The concatenated tokens.

    Raises
    ------
    KeyError
        If an element is not a known code.
    """
    # join() builds the result in one pass instead of repeated `+=` copies.
    return ''.join(ALPHABET_UNMOD_REV[code] for code in arr)

def numerical_to_alphabetical_str(s):
    """Decode a stringified list of integer codes, e.g. ``"[1, 9, 11]"``.

    All square brackets are removed, the remainder is split on commas and
    each piece is parsed with ``int`` (surrounding whitespace is tolerated).
    The codes are then decoded with ``numerical_to_alphabetical``. An empty
    list such as ``"[]"`` decodes to ``''``.

    Parameters
    ----------
    s : str
        Comma-separated integer codes, optionally wrapped in brackets.

    Returns
    -------
    str
        The decoded peptide string.

    Raises
    ------
    ValueError
        If a comma-separated piece is not an integer.
    KeyError
        If a code is not in ``ALPHABET_UNMOD_REV``.
    """
    body = s.replace('[', '').replace(']', '')
    # ''.split(',') yields [''], which int() rejects; an empty list has no codes.
    if not body.strip():
        return ''
    return numerical_to_alphabetical([int(token) for token in body.split(',')])

def align(dataset, reference, degree=1, frac=0.5):
    """Map ``dataset`` retention times onto the retention-time scale of ``reference``.

    Sequences present in both frames serve as anchors. A LOESS curve
    (``loess_1d``) is fitted from the dataset's retention time (x) to the
    reference's retention time (y) on those anchors. It is then evaluated at
    every retention time of ``dataset``.

    When a sequence occurs several times in a frame, only its last occurrence
    is used as an anchor.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Frame to realign; needs ``Sequence`` and ``Retention time`` columns.
        NOTE: its ``Retention time`` column is overwritten in place.
    reference : pandas.DataFrame
        Frame defining the target scale; same two columns.
    degree : int, optional
        Local polynomial degree passed to ``loess_1d`` (default 1).
    frac : float, optional
        Fraction of anchor points used for each local fit, passed to
        ``loess_1d`` (default 0.5).

    Returns
    -------
    pandas.DataFrame
        ``dataset`` itself (the same object) with realigned retention times.

    Raises
    ------
    ValueError
        If the two frames share no sequence.
    """
    # Sequences may be lists/arrays (unhashable), so key them as tuples.
    ref_keys = [tuple(seq) for seq in reference['Sequence'].tolist()]
    data_keys = [tuple(seq) for seq in dataset['Sequence'].tolist()]

    # Row position of each sequence; later duplicates overwrite earlier ones.
    ref_pos = {key: i for i, key in enumerate(ref_keys)}
    data_pos = {key: i for i, key in enumerate(data_keys)}

    # Deterministic (reference-order) intersection; iterating a set here
    # would depend on hash randomization for string sequences.
    common = [key for key in ref_pos if key in data_pos]
    if not common:
        raise ValueError('align: dataset and reference share no sequence')

    ref_idx = [ref_pos[key] for key in common]
    data_idx = [data_pos[key] for key in common]

    # The indices are row positions, so use .iloc: plain [] would look them
    # up as index labels, which differ unless the frames are RangeIndexed.
    rt_ref = reference['Retention time'].iloc[ref_idx].to_numpy(dtype=float)
    rt_data = dataset['Retention time'].iloc[data_idx].to_numpy(dtype=float)

    _, rt_aligned, _ = loess_1d(rt_data, rt_ref,
                                xnew=dataset['Retention time'].to_numpy(dtype=float),
                                degree=degree, frac=frac,
                                npoints=None, rotate=False, sigy=None)
    dataset['Retention time'] = rt_aligned
    return dataset

def filter_cysteine(df, col):
    """Keep only the rows of ``df`` whose ``col`` value contains no ``'C'``.

    Side effect: a boolean ``'cys'`` column (True when ``'C'`` is absent) is
    written onto ``df`` itself, and it is carried into the returned frame.
    The returned frame is re-indexed from 0.
    """
    df['cys'] = df[col].map(lambda seq: 'C' not in seq)
    kept = df.loc[df['cys']]
    return kept.reset_index(drop=True)

def compare_include_df(df, sub_df, save=True, path='temp.png'):
    """Pair the retention times of sequences shared by two frames and plot them.

    For every row of ``sub_df`` whose ``Sequence`` also occurs in ``df``, the
    retention time of the first matching row of ``df`` is paired with that
    row's own retention time. Rows without a match are skipped.

    The pairs are drawn as a scatter plot (x: ``sub_df``, y: ``df``) with a
    least-squares line fitted in the same orientation. The plot is annotated
    with ``r2_score(df_values, sub_df_values)``.

    Parameters
    ----------
    df, sub_df : pandas.DataFrame
        Frames with ``Sequence`` (hashable values) and ``Retention time``
        columns.
    save : bool, optional
        Write the figure to ``path`` when true (default True).
    path : str, optional
        Output image path (default ``'temp.png'``).

    Returns
    -------
    tuple of (list, list)
        ``(df_value_list, df_sub_value_list)``: the paired retention times
        from ``df`` and ``sub_df``, in ``sub_df`` row order.

    Raises
    ------
    ValueError
        If no sequence of ``sub_df`` occurs in ``df``.
    """
    # Explicit import: ``import scipy`` alone does not load scipy.stats on
    # older SciPy releases.
    from scipy import stats

    # First retention time per sequence in ``df``, built once instead of
    # rescanning ``df`` for every row of ``sub_df``.
    first_rt = {}
    for seq, rt in zip(df['Sequence'], df['Retention time']):
        first_rt.setdefault(seq, rt)

    df_value_list = []
    df_sub_value_list = []
    for seq, rt_sub in zip(sub_df['Sequence'], sub_df['Retention time']):
        if seq in first_rt:
            df_value_list.append(first_rt[seq])
            df_sub_value_list.append(rt_sub)

    if not df_value_list:
        raise ValueError('compare_include_df: no sequence of sub_df occurs in df')

    fig, ax = plt.subplots()
    ax.scatter(df_sub_value_list, df_value_list)
    # Fit y (df) against x (sub_df) so the line matches the scatter axes.
    linreg = stats.linregress(df_sub_value_list, df_value_list)
    x = np.array([min(df_sub_value_list), max(df_sub_value_list)])
    ax.plot(x, linreg.intercept + linreg.slope * x, 'r')
    # Axes-fraction coordinates keep the label in the top-left corner
    # whatever the retention-time range is.
    ax.annotate("r-squared = {:.3f}".format(r2_score(df_value_list, df_sub_value_list)),
                xy=(0.02, 0.98), xycoords='axes fraction', va='top')

    if save:
        fig.savefig(path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return df_value_list, df_sub_value_list

# data_ori = load_data('msms/msms30_01.txt').reset_index(drop=True)
# # data_ori['sequence'] = data_ori['sequence'].map(numerical_to_alphabetical)
#
# data_train = load_data('msms/msms16_01.txt').reset_index(drop=True)
# # data_train = pd.read_pickle('database/data_DIA_16_01.pkl').reset_index(drop=True)
# data_align = align(data_train, data_ori)
# data_align.to_pickle('database/data_DIA_16_01_aligned30_01.pkl')
#
# data_train = load_data('msms/msms17_01.txt').reset_index(drop=True)
# # data_train = pd.read_pickle('database/data_DIA_17_01.pkl').reset_index(drop=True)
# data_align = align(data_train, data_ori)
# data_align.to_pickle('database/data_DIA_17_01_aligned30_01.pkl')
#
# data_train = load_data('msms/msms20_01.txt').reset_index(drop=True)
# # data_train = pd.read_pickle('database/data_DIA_20_01.pkl').reset_index(drop=True)
# data_align = align(data_train, data_ori)
# data_align.to_pickle('database/data_DIA_20_01_aligned30_01.pkl')
#
# data_train = load_data('msms/msms23_01.txt').reset_index(drop=True)
# # data_train = pd.read_pickle('database/data_DIA_23_01.pkl').reset_index(drop=True)
# data_align = align(data_train, data_ori)
# data_align.to_pickle('database/data_DIA_23_01_aligned30_01.pkl')
#
# data_train = load_data('msms/msms24_01.txt').reset_index(drop=True)
# # data_train = pd.read_pickle('database/data_DIA_24_01.pkl').reset_index(drop=True)
# data_align = align(data_train, data_ori)
# data_align.to_pickle('database/data_DIA_24_01_aligned30_01.pkl')
#
# data_train = load_data('msms/msms30_01.txt').reset_index(drop=True)
# data_train = pd.read_pickle('database/data_DIA_30_01.pkl').reset_index(drop=True)
# # data_align = align(data_train, data_ori)
# data_train.to_pickle('database/data_DIA_30_01_aligned30_01.pkl')
#
# plt.scatter(data_train['Retention time'], data_align['Retention time'], s=1)
# plt.savefig('test_align_2.png')
#
# dataset_ref = pd.read_pickle('database/data_01_16_DIA_ISA_55.pkl')
# data_ref = Common_Dataset(dataset_ref, 25).data
# dataset_2 = pd.read_pickle('database/data_01_20_DIA_ISA_55.pkl')
# data_2 = Common_Dataset(dataset_2, 25).data
# dataset_3 = pd.read_pickle('database/data_01_17_DIA_ISA_55.pkl')
# data_3 = Common_Dataset(dataset_3, 25).data
# dataset_4 = pd.read_pickle('database/data_01_23_DIA_ISA_55.pkl')
# data_4 = Common_Dataset(dataset_4, 25).data
# data_align_3 = align(data_3, data_ref)
# data_align_4 = align(data_4, data_ref)
#
# data = pd.concat([data_ref, data_2, data_align_3, data_align_4], ignore_index=True)
# data = data.drop(columns='index')
# data['Sequence'] = data['Sequence'].map(numerical_to_alphabetical)
# num_data = data.shape[0]
# train_num = np.floor(num_data*0.8)
# train_size=0
# list_train=[]
# list_test=[]
# groups = data.groupby('Sequence')
# for seq, gr in groups:
#
#     train_size+= gr.shape[0]
#
#     if train_size>train_num:
#         list_test.append(gr)
#     else:
#         list_train.append(gr)
#
#
# dataset_train = pd.concat(list_train, ignore_index=True)
# dataset_test = pd.concat(list_test, ignore_index=True)
# dataset_train.to_pickle('database/data_DIA_ISA_55_train.pkl')
# dataset_test.to_pickle('database/data_DIA_ISA_55_test.pkl')

# data_train_1 = pd.read_pickle('database/data_DIA_ISA_55_test_30_01.pkl').reset_index(drop=True)
# data_train_2 = pd.read_pickle('database/data_DIA_ISA_55_train_30_01.pkl').reset_index(drop=True)
# data_ori = pd.read_csv('database/data_train.csv').reset_index(drop=True)
# data_ori['Sequence']=data_ori['sequence']
# data_ori['Retention time']=data_ori['irt']
# data_train = pd.concat([data_train_2,data_train_1]).reset_index(drop=True)
# data_align = align(data_train, data_ori)
#
# data_align.to_pickle('database/data_ISA_dual_align.pkl')


# Compare DIANN predictions to DIA measurements
# df_ori = pd.read_csv('database/data_train.csv')
# df_ori['Sequence']=df_ori['sequence']
# df_ori['Retention time']=df_ori['irt']
# df_diann = pd.read_csv('database/CIT_BASE_UP000584719_546.csv')
#
# df_ISA = pd.read_pickle('database/data_ISA_dual_align.pkl')
#
#
#
# df_diann_aligned = align(df_diann, df_ori)
#
# df_value_list, df_sub_value_list = compare_include_df(df_diann_aligned, df_ISA, True)




# Create augmented datasets from ISA data + column-invariant Prosit peptides
# df_base = pd.read_pickle('database/data_DIA_ISA_55_train.pkl')
# df_base = df_base[['Sequence','Retention time']]
#
# df_1 = pd.read_pickle('database/data_prosit_threshold_1.pkl')
# df_1['Sequence']= df_1['Sequence'].map(numerical_to_alphabetical_str)
#
# df_2 = pd.read_pickle('database/data_prosit_threshold_2.pkl')
# df_2['Sequence']= df_2['Sequence'].map(numerical_to_alphabetical_str)
#
# df_3 = pd.read_pickle('database/data_prosit_threshold_3.pkl')
# df_3['Sequence']= df_3['Sequence'].map(numerical_to_alphabetical_str)
#
# df_augmented_1 = pd.concat([df_1,df_base],axis=0).reset_index(drop=True)
# df_augmented_1.columns=['sequence','irt']
# df_augmented_1['state']='train'
# df_augmented_1.to_csv('database/data_ISA_augmented_1.csv')
#
# df_augmented_2 = pd.concat([df_2,df_base],axis=0).reset_index(drop=True)
# df_augmented_2.columns=['sequence','irt']
# df_augmented_2['state']='train'
# df_augmented_2.to_csv('database/data_ISA_augmented_2.csv')
#
# df_augmented_3 = pd.concat([df_3,df_base],axis=0).reset_index(drop=True)
# df_augmented_3.columns=['sequence','irt']
# df_augmented_3['state']='train'
# df_augmented_3.to_csv('database/data_ISA_augmented_3.csv')


# Test the intersection between the test and augmented datasets
# df_sup = pd.read_pickle('database/data_prosit_threshold_3.pkl')
# df_test = pd.read_pickle('database/data_DIA_ISA_55_test.pkl')
#
# inter = []
# n = 0
# df_sup['Sequence']= df_sup['Sequence'].map(numerical_to_alphabetical_str)
# groups_sup = df_sup.groupby('Sequence')
#
# groups_test = df_test.groupby('Sequence')
# for seq, _ in groups_sup:
#     for seq2, _ in groups_test:
#         if seq2==seq :
#             inter.append(seq)