import numpy as np
import pandas as pb
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import torch
import matplotlib.pyplot as plt
from model.model import ModelTransformer
from config import load_args
from data.dataset import load_data

# Token vocabulary: each token string maps to its integer id, assigned in the
# order listed below (0, 1, 2, ...). Id 0 is the empty string, so it
# contributes nothing when ids are decoded back into text.
ALPHABET_UNMOD = {
    token: token_id
    for token_id, token in enumerate((
        "",
        "A",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "K",
        "L",
        "M",
        "N",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "V",
        "W",
        "Y",
        "M(UniMod:35)",
        "CaC",
    ))
}

# Inverse lookup table: integer id -> token string.
ALPHABET_UNMOD_REV = {token_id: token for token, token_id in ALPHABET_UNMOD.items()}

def numerical_to_alphabetical_str(s, alphabet_rev=None):
    """Decode a stringified list of token ids back into a sequence string.

    *s* is text such as ``"[1, 2, 0]"`` (the textual form of a Python list of
    ints). Each id is looked up in *alphabet_rev* and the resulting tokens are
    concatenated; ids that map to ``""`` contribute nothing.

    Args:
        s: Comma-separated integer ids, optionally wrapped in ``[`` / ``]``.
        alphabet_rev: Mapping id -> token. Defaults to ``ALPHABET_UNMOD_REV``.

    Returns:
        The concatenated token string; ``""`` for an empty list like ``"[]"``.

    Raises:
        ValueError: if a non-blank field is not an integer.
        KeyError: if an id is missing from *alphabet_rev*.
    """
    if alphabet_rev is None:
        alphabet_rev = ALPHABET_UNMOD_REV
    body = s.replace('[', '').replace(']', '')
    # Skip blank fields so "[]" (or a trailing comma) decodes to "" instead of
    # failing on int('').
    ids = [int(field) for field in body.split(',') if field.strip()]
    return ''.join(alphabet_rev[token_id] for token_id in ids)

def load_lib(path):
    """Load a Parquet file and return its contents as a pandas DataFrame.

    Args:
        path: Location of the Parquet file, forwarded to ``pq.read_table``.

    Returns:
        The Arrow table converted with ``to_pandas()``.
    """
    return pq.read_table(path).to_pandas()

def extract_sequence(data_frame):
    """Build a prediction-input frame from the 'Modified.Sequence' column.

    Every sequence has 'M(UniMod:35)' rewritten to '-OxM-' (NOTE(review):
    presumably the spelling the downstream loader expects - confirm). Rows
    whose rewritten sequence contains 'U' (a letter absent from
    ALPHABET_UNMOD) are then dropped, and constant placeholder columns are
    added.

    Args:
        data_frame: DataFrame with a string column 'Modified.Sequence'.

    Returns:
        DataFrame with columns 'sequence', 'irt_scaled' (always 0) and
        'state' (always 'holdout'), without duplicate rows. The original
        index label of each kept row is preserved.
    """
    # The rewrite must happen before the 'U' filter: 'M(UniMod:35)' itself
    # contains a 'U'.
    sequences = data_frame['Modified.Sequence'].map(
        lambda text: text.replace('M(UniMod:35)', '-OxM-')
    )
    has_u = sequences.map(lambda text: 'U' in text)
    kept = sequences[has_u == False]  # noqa: E712 - keeps original mask semantics

    result = pd.DataFrame({'sequence': kept})
    result['irt_scaled'] = 0
    result['state'] = 'holdout'
    return result.drop_duplicates()

def predict(data_pred, model, output_path):
    """Run *model* over every batch of *data_pred* and write results to CSV.

    Args:
        data_pred: Iterable of ``(seq, rt)`` tensor batches (in this file, the
            object returned by ``load_data``).
        model: torch module called as ``model(seq)`` to get predicted RTs.
        output_path: Destination passed to ``DataFrame.to_csv``.

    The CSV has columns 'rt pred', 'seq' and 'true rt'. Each batch is
    flattened along its first dimension via ``tolist()``/``extend``. The model
    is switched to eval mode. Unlike a ``requires_grad = False`` loop, its
    parameters' ``requires_grad`` flags are left untouched.
    """
    model.eval()

    # Send inputs to wherever the model's weights live. For a parameter-less
    # model, fall back to "CUDA if available".
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pred_rt, seqs, true_rt = [], [], []
    # no_grad disables autograd for inference without permanently mutating
    # the caller's model.
    with torch.no_grad():
        for seq, rt in data_pred:
            seq = seq.to(device)
            # Call the module rather than .forward() so registered hooks run.
            pr_rt = model(seq)
            pred_rt.extend(pr_rt.cpu().tolist())
            seqs.extend(seq.cpu().tolist())
            # .float() kept so the written values match the float32 targets.
            true_rt.extend(rt.float().cpu().tolist())

    data_frame = pd.DataFrame({'rt pred': pred_rt, 'seq': seqs, 'true rt': true_rt})
    data_frame.to_csv(output_path)


if __name__ == '__main__':
    # NOTE(review): the commented-out steps below are a disabled workflow.
    # It plots RT histograms, decodes a previous prediction CSV back to
    # sequences, merges the predicted RT into the spectral library on
    # 'Modified.Sequence' and writes a new Parquet library. Kept for reference.
    # df = load_lib('data/spectral_lib/first_lib.parquet')
    #
    # plt.hist(df['RT'])
    # plt.savefig('test.png')
    #
    # df_2 = pd.read_csv('data/data_prosit/data.csv')
    #
    # plt.clf()
    # plt.hist(df_2['irt'])
    # plt.savefig('test2.png')

    # df_2 = extract_sequence(df).reset_index(drop=True)
    #
    # pred = pd.read_csv('../output/out_uniprot_base.csv')
    #
    # pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
    #
    # pred['Modified.Sequence']=pred['seq']
    #
    # result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
    #
    # result['RT']=result['rt pred']
    #
    # result = result.drop('rt pred', axis=1)
    #
    # table = pa.Table.from_pandas(result)
    #
    # pq.write_table(table, 'spectral_lib/custom_first_lib.parquet')

    # Sequence length given to both the model and the data loader; defined
    # once so the two values cannot drift apart.
    seq_length = 30

    args = load_args()

    model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                             n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                             decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
                             seq_length=seq_length)

    if torch.cuda.is_available():
        model = model.cuda()

    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only machine. load_state_dict then copies the tensors onto the
    # model's own device.
    state_dict = torch.load(args.model_weigh, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=seq_length,
                          mode=args.split_test, seq_col=args.seq_test)

    predict(data_test, model, args.output)