import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

# Token -> integer code table used when encoding sequences.
# Code 0 is '_' (the character padding() appends); codes 1-20 are single
# letters (presumably amino-acid one-letter codes -- confirm with the data
# source); 'CaC' and 'OxM' are the three-letter modification names that
# alphabetical_to_numerical() recognises between '-' delimiters.
# Codes are assigned by position in the list below, so order matters.
ALPHABET_UNMOD = {
    token: code
    for code, token in enumerate(
        [*"_ACDEFGHIKLMNPQRSTVWY", "CaC", "OxM"]
    )
}

# Alternative token -> integer code table: padding '_' (0), four special
# tokens (<mask>, <cls>, <sep>, <unk>), then upper-case letters A-Z without
# 'J' (codes 5-29). Codes are assigned by position in the list below.
# NOTE(review): nothing in the visible part of this file reads this table.
IUPAC_VOCAB = {
    token: code
    for code, token in enumerate(
        ["_", "<mask>", "<cls>", "<sep>", "<unk>", *"ABCDEFGHIKLMNOPQRSTUVWXYZ"]
    )
}

# Inverse of ALPHABET_UNMOD (integer code -> token), used to decode sequences.
ALPHABET_UNMOD_REV = dict(zip(ALPHABET_UNMOD.values(), ALPHABET_UNMOD.keys()))


def padding(dataframe, columns, length):
    """Drop over-long sequences and right-pad the remaining ones, in place.

    The effective length of a sequence is ``len(seq) - 2 * seq.count('-')``,
    the same quantity ``alphabetical_to_numerical`` uses as its number of
    output codes (a ``-XXX-`` tag is five characters but yields one code).

    Args:
        dataframe: pandas DataFrame to modify in place.
        columns: label of the single column holding the sequence strings.
        length: target effective length. Rows whose sequence is longer are
            removed; shorter sequences are padded with ``'_'`` up to it.

    Returns:
        None. ``dataframe`` is mutated: offending rows are removed, the index
        is renumbered 0..n-1 when rows were removed, and ``columns`` is
        replaced by the padded strings.
    """
    def effective_length(seq):
        return len(seq) - 2 * seq.count('-')

    too_long = np.array(
        [effective_length(seq) > length for seq in dataframe[columns]],
        dtype=bool,
    )
    if too_long.any():
        # DataFrame.drop returns a new frame unless inplace=True; the caller
        # only sees the removal if it happens on this very object.
        dataframe.drop(index=dataframe.index[too_long], inplace=True)
        # Keep labels contiguous so lookups such as dataframe[columns][i]
        # with i in range(len(dataframe)) stay valid after the removal.
        dataframe.reset_index(drop=True, inplace=True)

    # Every remaining sequence has effective_length <= length, so the pad
    # count is never negative.
    dataframe[columns] = dataframe[columns].map(
        lambda seq: seq + (length - effective_length(seq)) * '_'
    )
    # Deliberately no check of the raw character count after padding: a
    # padded sequence containing '-' tags is legitimately longer than
    # `length` characters while having exactly `length` effective positions.


def alphabetical_to_numerical(seq, vocab):
    """Encode a sequence string as a NumPy array of integer codes.

    Plain characters are looked up in ``ALPHABET_UNMOD``. A modification is
    written inline as ``-XXX-`` (three-letter name between two dashes) and
    contributes a single code: ``CaC`` -> 21, ``OxM`` -> 22.

    Args:
        seq: sequence string, optionally already padded with ``'_'``.
        vocab: vocabulary selector. Kept for interface compatibility; both
            the ``'unmod'`` path and the fallback path have always encoded
            with ``ALPHABET_UNMOD``, so the value does not affect the result.
            NOTE(review): confirm whether a non-'unmod' value was meant to
            select ``IUPAC_VOCAB``.

    Returns:
        ``np.ndarray`` with ``len(seq) - 2 * seq.count('-')`` codes.

    Raises:
        ValueError: a ``-XXX-`` tag names a modification other than
            ``CaC`` or ``OxM``.
        KeyError: a plain character is not a key of ``ALPHABET_UNMOD``.
    """
    modification_codes = {'CaC': 21, 'OxM': 22}
    num = []
    # Number of extra characters consumed so far by '-XXX-' tags.
    offset = 0
    # Each tag holds two dashes and is five characters wide but yields one
    # code, hence 2 * count('-') fewer outputs than characters.
    for i in range(len(seq) - 2 * seq.count('-')):
        pos = i + offset
        if seq[pos] != '-':
            num.append(ALPHABET_UNMOD[seq[pos]])
            continue
        name = seq[pos + 1:pos + 4]
        if name not in modification_codes:
            # Must be a real exception instance: raising a bare string is
            # itself a TypeError in Python 3 and hides this message.
            raise ValueError('Modification not supported')
        num.append(modification_codes[name])
        # The loop index advances by one; skip the other four tag characters.
        offset += 4
    return np.array(num)

def numerical_to_alphabetical(arr):
    """Decode a sequence of integer codes back into its string form.

    Each element ``arr[i]`` is looked up in ``ALPHABET_UNMOD_REV`` and the
    resulting tokens are concatenated in order.

    Args:
        arr: indexable, sized collection of codes present in
            ``ALPHABET_UNMOD_REV``.

    Returns:
        The decoded string (``''`` for an empty input).

    Raises:
        KeyError: a code is not a key of ``ALPHABET_UNMOD_REV``.
    """
    tokens = []
    position = 0
    total = len(arr)
    while position < total:
        tokens.append(ALPHABET_UNMOD_REV[arr[position]])
        position += 1
    return ''.join(tokens)

def zero_to_minus(arr):
    """Overwrite every entry not greater than 1e-5 with -1.0, in place.

    This covers zeros, near-zero positives and all negative values.

    Args:
        arr: array supporting boolean-mask assignment (e.g. ``np.ndarray``).

    Returns:
        The same ``arr`` object, after modification.
    """
    at_or_below_threshold = arr <= 0.00001
    arr[at_or_below_threshold] = -1.
    return arr


class Common_Dataset(Dataset):
    """Torch ``Dataset`` serving rows of a pandas DataFrame as tensors.

    The frame must provide the columns ``'Sequence'``, ``'Retention time'``,
    ``'Spectra'`` and ``'Charge'``. A ``'file'`` column is required only
    while file mode is enabled.
    """

    def __init__(self, dataframe, length, pad=True, convert=True, vocab='unmod', file=False):
        """Prepare the frame for indexed access.

        Args:
            dataframe: source DataFrame; a re-indexed copy is stored.
            length: target sequence length forwarded to ``padding``.
            pad: when true, run ``padding`` on the ``'Sequence'`` column.
            convert: when true, encode ``'Sequence'`` with
                ``alphabetical_to_numerical`` and pass every ``'Spectra'``
                cell through ``zero_to_minus``.
            vocab: vocabulary selector forwarded to
                ``alphabetical_to_numerical``.
            file: initial file mode (see ``__getitem__``).
        """
        print('Data loader Initialisation')
        # reset_index() yields a new frame labelled 0..n-1, which the
        # label-based lookups in __getitem__ depend on.
        self.data = dataframe.reset_index()
        self.file_mode = file
        if pad:
            print('Padding')
            padding(self.data, 'Sequence', length)

        if convert:
            print('Converting')
            self.data['Sequence'] = self.data['Sequence'].map(lambda x: alphabetical_to_numerical(x, vocab))
            self.data['Spectra'] = self.data['Spectra'].map(zero_to_minus)

    def __getitem__(self, index: int):
        """Return one sample as a tuple of tensors.

        Returns:
            ``(sequence, charge, retention_time, spectra)``, with the
            retention time cast via ``.float()``. In file mode a fifth
            tensor built from the ``'file'`` column is appended.
        """
        sample = (
            torch.tensor(self.data['Sequence'][index]),
            torch.tensor(self.data['Charge'][index]),
            torch.tensor(self.data['Retention time'][index]).float(),
            torch.tensor(self.data['Spectra'][index]),
        )
        if self.file_mode:
            # Only touch the 'file' column when it is requested, so frames
            # without that column remain usable outside file mode.
            return sample + (torch.tensor(self.data['file'][index]),)
        return sample

    def set_file_mode(self, b):
        """Enable or disable returning the ``'file'`` value with each sample."""
        self.file_mode = b

    def __len__(self) -> int:
        """Return the number of rows in the stored frame."""
        return self.data.shape[0]


def load_data(path_train, path_val, path_test, batch_size, length, pad=False, convert=False, vocab = 'unmod'):
    """Load the three pickled splits and wrap each one in a ``DataLoader``.

    Args:
        path_train: path of the pickled training DataFrame.
        path_val: path of the pickled validation DataFrame.
        path_test: path of the pickled test DataFrame.
        batch_size: batch size used by all three loaders.
        length: target sequence length forwarded to ``Common_Dataset``.
        pad: forwarded to ``Common_Dataset`` (defaults to False here).
        convert: forwarded to ``Common_Dataset`` (defaults to False here).
        vocab: vocabulary selector forwarded to ``Common_Dataset``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; each loader serves the
        split named by the corresponding path argument.
    """
    print('Loading data')
    # NOTE(security): pd.read_pickle unpickles arbitrary Python objects and
    # can execute code; only pass paths to trusted files.
    data_train = pd.read_pickle(path_train)
    data_val = pd.read_pickle(path_val)
    data_test = pd.read_pickle(path_test)

    # Each dataset is built from its own split so the returned val/test
    # loaders are not swapped.
    train = Common_Dataset(data_train, length, pad, convert, vocab)
    val = Common_Dataset(data_val, length, pad, convert, vocab)
    test = Common_Dataset(data_test, length, pad, convert, vocab)

    # NOTE(review): the evaluation loaders are shuffled as well; kept as-is.
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)

    return train_loader, val_loader, test_loader