import random
from collections import Counter

import numpy as np
import torch
import torch.distributions as dist

# Integer code for each of the 20 standard (unmodified) amino acids.
# Codes are assigned 1..20 following the one-letter alphabet order below;
# in this numbering lysine (K) is 9 and arginine (R) is 15.
ALPHABET_UNMOD = {
    residue: code
    for code, residue in enumerate("ACDEFGHIKLMNPQRSTVWY", start=1)
}

def reverse_protein(seq):
    """Return *seq* reversed end-to-end, using sequence slicing."""
    backwards = slice(None, None, -1)
    return seq[backwards]

def reverse_peptide(seq, format='numerical'):
    """Reverse a peptide while keeping its K/R residues at their positions.

    Lysine and arginine stay exactly where they are; every other residue is
    reversed among the remaining slots, e.g. ``"PEPTIDEK" -> "EDITPEPK"``.

    Args:
        seq: The peptide. With ``format='numerical'``, a mutable sequence of
            integer codes (K == 9, R == 15, matching ``ALPHABET_UNMOD``).
            With ``format='alphabetical'``, a ``str`` or a mutable sequence
            of one-letter codes.
        format: ``'numerical'`` or ``'alphabetical'``.

    Returns:
        The reversed peptide. Mutable inputs (e.g. a list) are modified in
        place and returned, as before; a ``str`` input yields a new ``str``.

    Raises:
        ValueError: If ``format`` is not one of the supported values.
    """
    if format == 'numerical':
        fixed = (9, 15)
    elif format == 'alphabetical':
        fixed = ('K', 'R')
    else:
        raise ValueError(f"unknown format: {format!r}")

    # Strings are immutable, so work on a list and join at the end.
    is_str = isinstance(seq, str)
    residues = list(seq) if is_str else seq

    # Reverse only the residues that are free to move, among their own slots.
    # Swapping mirror positions directly would drag a K/R sitting at the
    # mirror index into a new position.
    movable = [i for i, aa in enumerate(residues) if aa not in fixed]
    picked = [residues[i] for i in movable]
    for pos, aa in zip(movable, reversed(picked)):
        residues[pos] = aa

    return ''.join(residues) if is_str else residues

def shuffle_protein(seq):
    """Return a randomly permuted copy of *seq*; the input is not modified.

    Relies on ``seq.copy()``, so *seq* must provide a ``copy`` method
    (e.g. a list or numpy array).
    """
    permuted = seq.copy()  # copy first so the caller's sequence stays intact
    random.shuffle(permuted)
    return permuted

def shuffle_peptide(seq, format='numerical'):
    """Shuffle a peptide while keeping its K/R residues at their positions.

    Lysine and arginine stay exactly where they are; all other residues are
    randomly permuted among the remaining slots (using ``random.shuffle``).

    Args:
        seq: The peptide. With ``format='numerical'``, a sequence of integer
            codes (K == 9, R == 15, matching ``ALPHABET_UNMOD``) that supports
            ``.copy()`` (e.g. a list). With ``format='alphabetical'``, a
            ``str`` or such a sequence of one-letter codes.
        format: ``'numerical'`` or ``'alphabetical'``.

    Returns:
        A shuffled copy of *seq* (a ``str`` for ``str`` input); *seq* itself
        is not modified.

    Raises:
        ValueError: If ``format`` is not one of the supported values.
    """
    if format == 'numerical':
        fixed = (9, 15)
    elif format == 'alphabetical':
        fixed = ('K', 'R')
    else:
        raise ValueError(f"unknown format: {format!r}")

    is_str = isinstance(seq, str)
    shuffled = list(seq) if is_str else seq.copy()

    # Permute only the residues that are free to move, then write them back
    # into their own slots so every K/R keeps its original index.
    movable = [i for i, aa in enumerate(shuffled) if aa not in fixed]
    picked = [shuffled[i] for i in movable]
    random.shuffle(picked)
    for pos, aa in zip(movable, picked):
        shuffled[pos] = aa

    return ''.join(shuffled) if is_str else shuffled
def random_aa(database, format='numerical'):
    """Sample a random residue sequence with the database's residue composition.

    Every residue returned by ``database.unroll()`` is counted, and
    ``len(total_seq)`` residues are drawn i.i.d. from the resulting empirical
    distribution.

    Args:
        database: Object exposing ``unroll()``. Presumably this returns a flat
            iterable of hashable residue symbols (ints or one-letter strings)
            — TODO confirm against the database class.
        format: Currently unused; kept for interface compatibility with
            ``random_aa_trypsin``.

    Returns:
        A list of sampled residue symbols, the same length as the unrolled
        database (empty if the database is empty).
    """
    total_seq = list(database.unroll())
    if not total_seq:
        # Categorical cannot be built from an empty probability vector.
        return []

    counts = Counter(total_seq)
    symbols = list(counts)
    weights = torch.tensor([counts[s] for s in symbols], dtype=torch.float)
    d = dist.Categorical(probs=weights / weights.sum())

    length = len(total_seq)
    # sample() takes a shape, not an int: draw `length` category indices.
    draws = d.sample((length,))
    new_seq = [symbols[k] for k in draws.tolist()]
    # TODO: similar cutting — split new_seq into peptides like the database's.
    return new_seq


def random_aa_trypsin(database, format='numerical'):
    """Sample a random K/R-free residue sequence matching the database composition.

    All lysine/arginine residues are removed from ``database.unroll()``. The
    remaining residues are counted, and as many residues as remain are drawn
    i.i.d. from that empirical distribution.

    Args:
        database: Object exposing ``unroll()``. Presumably this returns a flat
            iterable of hashable residue symbols — TODO confirm against the
            database class.
        format: ``'numerical'`` (K == 9, R == 15, matching ``ALPHABET_UNMOD``)
            or ``'alphabetical'`` (``'K'`` / ``'R'``).

    Returns:
        A list of sampled residue symbols containing no K/R, the same length
        as the database with K/R removed (empty if nothing remains).

    Raises:
        ValueError: If ``format`` is not one of the supported values.
    """
    if format == 'numerical':
        cleavage = (9, 15)
    elif format == 'alphabetical':
        cleavage = ('R', 'K')
    else:
        raise ValueError(f"unknown format: {format!r}")

    # Drop *every* K/R occurrence (list.remove only dropped the first one and
    # raised if none was present). Building a new list also avoids mutating
    # whatever storage unroll() hands back.
    total_seq = [aa for aa in database.unroll() if aa not in cleavage]
    if not total_seq:
        return []

    counts = Counter(total_seq)
    symbols = list(counts)
    weights = torch.tensor([counts[s] for s in symbols], dtype=torch.float)
    d = dist.Categorical(probs=weights / weights.sum())

    length = len(total_seq)
    # sample() takes a shape, not an int: draw `length` category indices.
    draws = d.sample((length,))
    new_seq = [symbols[k] for k in draws.tolist()]
    # TODO: similar cutting — re-insert K/R cleavage sites like the database's.
    return new_seq