"""Run a trained ``ModelTransformer`` on a test dataset and dump predictions to CSV.

When executed as a script, the arguments returned by ``load_args`` select the
model hyper-parameters, the weight file, the test dataset and the output path.
The module also provides small helpers for reading a parquet library and for
converting between integer-encoded and letter-encoded peptide sequences.
"""
import numpy as np
import pandas as pb  # duplicate alias of pandas; not referenced in this module
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import torch
import matplotlib.pyplot as plt

from model.model import ModelTransformer
from config import load_args
from data.dataset import load_data

# Token -> integer index mapping. Index 0 maps to the empty string, so zeros
# contribute nothing when an index sequence is decoded back to text.
ALPHABET_UNMOD = {
    "": 0,
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
    "M(UniMod:35)": 21,
    "CaC": 22,
}

# Integer index -> token (inverse of ALPHABET_UNMOD).
ALPHABET_UNMOD_REV = {v: k for k, v in ALPHABET_UNMOD.items()}

# Sequence length handed to both the model constructor and the data loader.
SEQ_LENGTH = 30


def numerical_to_alphabetical_str(s):
    """Decode a stringified list of integer indices into a token string.

    Args:
        s: Text such as ``"[1, 2, 0]"``: comma-separated integers, optionally
            wrapped in square brackets.

    Returns:
        The concatenation of ``ALPHABET_UNMOD_REV[index]`` for every index in
        ``s``. An input without any index (e.g. ``"[]"``) yields ``""``.

    Raises:
        ValueError: If a non-empty item is not an integer literal.
        KeyError: If an index is not present in ``ALPHABET_UNMOD_REV``.
    """
    items = s.replace('[', '').replace(']', '').split(',')
    # Blank items (empty list, trailing comma) are skipped instead of being
    # passed to int(), which would raise on an empty string.
    return ''.join(ALPHABET_UNMOD_REV[int(item)] for item in items if item.strip())


def load_lib(path):
    """Read the parquet file at ``path`` and return it as a pandas DataFrame."""
    return pq.read_table(path).to_pandas()


def extract_sequence(data_frame):
    """Build a de-duplicated table of sequences from a library table.

    Args:
        data_frame: DataFrame with a ``'Modified.Sequence'`` column of strings.

    Returns:
        A new DataFrame (index taken from the retained input rows) with the
        columns ``'sequence'``, ``'irt_scaled'`` (always 0) and ``'state'``
        (always ``'holdout'``), with duplicate rows dropped.
    """
    # The rewrite must happen before the 'U' filter: 'M(UniMod:35)' itself
    # contains a 'U' and would otherwise be filtered out.
    sequences = data_frame['Modified.Sequence'].map(
        lambda x: x.replace('M(UniMod:35)', '-OxM-'))
    # astype(bool) keeps the mask boolean even when the input is empty.
    contains_u = sequences.map(lambda x: 'U' in x).astype(bool)

    # Build a fresh frame rather than adding columns to a filtered slice,
    # which is what triggers pandas' SettingWithCopyWarning.
    df_pred = pd.DataFrame({'sequence': sequences[~contains_u]})
    df_pred['irt_scaled'] = 0
    df_pred['state'] = 'holdout'
    return df_pred.drop_duplicates()


def _model_device(model):
    """Return the device of the model's first parameter (CUDA-if-available fallback)."""
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def predict(data_pred, model, output_path):
    """Run ``model`` over ``data_pred`` and write the results to a CSV file.

    The model is switched to eval mode and all of its parameters are left with
    ``requires_grad = False`` after the call.

    Args:
        data_pred: Iterable yielding ``(seq, rt)`` tensor pairs.
            NOTE(review): presumably a DataLoader of batches - confirm.
        model: ``torch.nn.Module`` mapping a ``seq`` tensor to predictions.
        output_path: Destination passed to ``DataFrame.to_csv``.

    Returns:
        None. The CSV has the columns ``'rt pred'``, ``'seq'`` and ``'true rt'``.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # Inputs follow the model's device instead of assuming that the model is
    # on CUDA whenever CUDA happens to be available.
    device = _model_device(model)

    pred_rt, seqs, true_rt = [], [], []
    with torch.no_grad():
        for seq, rt in data_pred:
            seq = seq.to(device)
            # Call the module rather than .forward() so registered hooks run.
            pr_rt = model(seq)
            pred_rt.extend(pr_rt.detach().cpu().tolist())
            seqs.extend(seq.cpu().tolist())
            # The targets are only written out, so they never need to visit
            # the model's device.
            true_rt.extend(rt.float().cpu().tolist())

    data_frame = pd.DataFrame({'rt pred': pred_rt, 'seq': seqs, 'true rt': true_rt})
    data_frame.to_csv(output_path)


def main():
    """Build the model from the parsed arguments, load its weights and predict."""
    args = load_args()

    model = ModelTransformer(encoder_ff=args.encoder_ff,
                             decoder_rt_ff=args.decoder_rt_ff,
                             n_head=args.n_head,
                             encoder_num_layer=args.encoder_num_layer,
                             decoder_rt_num_layer=args.decoder_rt_num_layer,
                             drop_rate=args.drop_rate,
                             embedding_dim=args.embedding_dim,
                             acti=args.activation,
                             norm=args.norm_first,
                             seq_length=SEQ_LENGTH)
    if torch.cuda.is_available():
        model = model.cuda()

    # map_location='cpu' lets a checkpoint saved from a GPU be read on a
    # CPU-only host; load_state_dict then copies the values into the model's
    # parameters on whatever device they live on.
    state_dict = torch.load(args.model_weigh, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    data_test = load_data(data_source=args.dataset_test,
                          batch_size=args.batch_size,
                          length=SEQ_LENGTH,
                          mode=args.split_test,
                          seq_col=args.seq_test)

    predict(data_test, model, args.output)


if __name__ == '__main__':
    # Earlier experiments, kept for reference (disabled):
    #
    # df = load_lib('data/spectral_lib/first_lib.parquet')
    #
    # plt.hist(df['RT'])
    # plt.savefig('test.png')
    #
    # df_2 = pd.read_csv('data/data_prosit/data.csv')
    #
    # plt.clf()
    # plt.hist(df_2['irt'])
    # plt.savefig('test2.png')
    # df_2 = extract_sequence(df).reset_index(drop=True)
    #
    # pred = pd.read_csv('../output/out_uniprot_base.csv')
    #
    # pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
    #
    # pred['Modified.Sequence']=pred['seq']
    #
    # result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
    #
    # result['RT']=result['rt pred']
    #
    # result = result.drop('rt pred', axis=1)
    #
    # table = pa.Table.from_pandas(result)
    #
    # pq.write_table(table, 'spectral_lib/custom_first_lib.parquet')
    main()