diff --git a/diann_lib_processing.py b/diann_lib_processing.py
index 483c773060ccfe9ed7e426485a2e779fc273e331..d5edc205c9fb00cadb288ec52b36d2cc4b6a9c96 100644
--- a/diann_lib_processing.py
+++ b/diann_lib_processing.py
@@ -5,7 +5,7 @@
 import pyarrow.parquet as pq
 import pyarrow as pa
 import torch
 import matplotlib.pyplot as plt
-# from loess.loess_1d import loess_1d
+from loess.loess_1d import loess_1d
 from model.model import ModelTransformer
 from config import load_args
@@ -96,98 +96,94 @@ def predict(data_pred, model, output_path):
 
 
 if __name__ =='__main__':
 
-    # df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet')
-    # df = extract_sequence(df)
-    # df.to_csv('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv')
-    # plt.hist(df['RT'])
-    # plt.savefig('test.png')
-    #
-    # df_2 = pd.read_csv('data_prosit/data.csv')
-    #
-    # plt.clf()
-    # plt.hist(df_2['irt'])
-    # plt.savefig('test2.png')
-    #
-    # df_2 = extract_sequence(df).reset_index(drop=True)
-    #
-    # pred = pd.read_csv('../output/out_lib_CITBASE_try_contaminant.csv')
-    #
-    # pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
-    #
-    # pred['Modified.Sequence']=pred['seq']
-    #
-    # result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
-    #
-    #
-    #
-    # #alignment
-    #
-    # ref = pd.read_csv('data_prosit/data_noc.csv')
-    # df_ISA = pd.read_csv('data_ISA/data_aligned_isa_noc.csv')
-    #
-    # dataset, reference, column_dataset, column_ref, seq_data, seq_ref = df_ISA, ref, 'irt_scaled', 'irt', 'sequence','sequence',
-    #
-    # dataset_ref=dataset[dataset['state']=='train']
-    # dataset_unique = dataset_ref[[seq_data,column_dataset]].groupby(seq_data).mean()
-    # print('unique',len(dataset_unique))
-    # reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
-    # seq_ref = reference_unique.index
-    # seq_common = dataset_unique.index
-    # seq_ref = seq_ref.tolist()
-    # seq_common = seq_common.tolist()
-    #
-    # seq_ref = [tuple(l) for l in seq_ref]
-    # seq_common = [tuple(l) for l in seq_common]
-    #
-    # ind_dict_ref = dict((k, i) for i, k in enumerate(seq_ref))
-    # inter = set(ind_dict_ref).intersection(seq_common)
-    # print(len(inter))
-    #
-    # ind_dict_ref = [ind_dict_ref[x] for x in inter]
-    #
-    # indices_common = dict((k, i) for i, k in enumerate(seq_common))
-    # indices_common = [indices_common[x] for x in inter]
-    #
-    #
-    # rt_ref = reference_unique[column_ref][ind_dict_ref].reset_index()
-    # rt_data = dataset_unique[column_dataset][indices_common].reset_index()
-    #
-    # plt.scatter(rt_data[column_dataset].tolist(),rt_ref[column_ref].tolist(),s=0.1)
-    # plt.savefig('test.png')
-    #
-    # #NaN values break the realignment (temporary fix: replace them with 0)
-    # result['rt pred']=result['rt pred'].fillna(value=0)
-    # xout, yout, wout = loess_1d(np.array(rt_data[column_dataset].tolist()), np.array(rt_ref[column_ref].tolist()),
-    #                             xnew=result['rt pred'],
-    #                             degree=1,
-    #                             npoints=None, rotate=False, sigy=None)
-    #
-    #
-    # #writing results
-    #
-    # result['RT'] = yout
-    #
-    # result = result.drop('rt pred', axis=1)
-    #
-    # table = pa.Table.from_pandas(result)
-    #
-    # pq.write_table(table, 'spectral_lib/first_lib_contaminant_prosit_aligned.parquet')
-    #
+    df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet')
+    df_2 = pd.read_csv('data_prosit/data.csv')
 
-    args = load_args()
+    plt.clf()
+    plt.hist(df_2['irt'])
+    plt.savefig('test2.png')
 
-    model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
-                             n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
-                             decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
-                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)
+    df_2 = extract_sequence(df).reset_index(drop=True)
 
-    if torch.cuda.is_available():
-        model = model.cuda()
+    pred = pd.read_csv('../output/out_transfer_prosit_isa_1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv')
 
-    model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
+    pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
 
-    data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,
-                          seq_col=args.seq_test)
+    pred['Modified.Sequence']=pred['seq']
 
-    predict(data_test, model, args.output)
+    result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
+
+
+
+    #alignment
+
+    ref = pd.read_csv('data_prosit/data_noc.csv')
+    df_ISA = pd.read_csv('data_ISA/data_aligned_isa_noc.csv')
+
+    dataset, reference, column_dataset, column_ref, seq_data, seq_ref = df_ISA, ref, 'irt_scaled', 'irt', 'sequence','sequence',
+
+    dataset_ref=dataset[dataset['state']=='train']
+    dataset_unique = dataset_ref[[seq_data,column_dataset]].groupby(seq_data).mean()
+    print('unique',len(dataset_unique))
+    reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
+    seq_ref = reference_unique.index
+    seq_common = dataset_unique.index
+    seq_ref = seq_ref.tolist()
+    seq_common = seq_common.tolist()
+
+    seq_ref = [tuple(l) for l in seq_ref]
+    seq_common = [tuple(l) for l in seq_common]
+
+    ind_dict_ref = dict((k, i) for i, k in enumerate(seq_ref))
+    inter = set(ind_dict_ref).intersection(seq_common)
+    print(len(inter))
+
+    ind_dict_ref = [ind_dict_ref[x] for x in inter]
+
+    indices_common = dict((k, i) for i, k in enumerate(seq_common))
+    indices_common = [indices_common[x] for x in inter]
+
+
+    rt_ref = reference_unique[column_ref][ind_dict_ref].reset_index()
+    rt_data = dataset_unique[column_dataset][indices_common].reset_index()
+
+    plt.scatter(rt_data[column_dataset].tolist(),rt_ref[column_ref].tolist(),s=0.1)
+    plt.savefig('test.png')
+
+    #NaN values break the realignment (temporary fix: replace them with 0)
+    result['rt pred']=result['rt pred'].fillna(value=0)
+    xout, yout, wout = loess_1d(np.array(rt_data[column_dataset].tolist()), np.array(rt_ref[column_ref].tolist()),
+                                xnew=result['rt pred'],
+                                degree=1,
+                                npoints=None, rotate=False, sigy=None)
+
+
+    #writing results
+
+    result['RT'] = yout
+
+    result = result.drop('rt pred', axis=1)
+
+    table = pa.Table.from_pandas(result)
+
+    pq.write_table(table, 'spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang_finetune_aligned.parquet')
+
+
+
+    # args = load_args()
+    #
+    # model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
+    #                          n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
+    #                          decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
+    #                          embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)
+    #
+    # if torch.cuda.is_available():
+    #     model = model.cuda()
+    #
+    # model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
+    #
+    # data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,
+    #                       seq_col=args.seq_test)
+    #
+    # predict(data_test, model, args.output)
diff --git a/identification/result_extraction.py b/identification/result_extraction.py
index e144a0ff7a3fbb38a5b4e46fc8abbc0ce4ed9fe2..3ef2a9a778d79843ffdd7078032ba094ea6337ac 100644
--- a/identification/result_extraction.py
+++ b/identification/result_extraction.py
@@ -41,7 +41,19 @@ def compare_error(path_1,path_2):
     plt.savefig('error2.png')
     return error_1,error_2
 
 
+def compare_with_db(path):
+    df = pd.read_csv(path, sep='\t', encoding='latin-1')
+    df_ref = pd.read_excel('250205_All_Peptides_panel_ID_+_RES.xlsx',names=['peptide','fonction'])
+    df2=df[df['Stripped.Sequence'].isin(df_ref['peptide'].to_list())]
+    corespondance = pd.merge(df2,df_ref,left_on='Stripped.Sequence',right_on='peptide',how="left")
+
+    return corespondance
+
 if __name__ == '__main__':
-    # compare_id('CITCRO_ANA_3/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv', 'CITCRO_ANA_3/report_finetune.tsv','CITCRO_ANA_3')
-    e1,e2 = compare_error('CITCRO_ANA_3/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv')
\ No newline at end of file
+    # compare_id('CITAMA_ANA_5/julie_custom_nolib.tsv', 'CITAMA_ANA_5/julie_base_nolib.tsv', 'CITAMA_ANA_5/julie_finetune_nolib.tsv','CITAMA_ANA_5_julie_no_lib')
+    # e1,e2 = compare_error('CITAMA_ANA_5/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv')
+
+    cor_base = compare_with_db('CITAMA_ANA_5/julie_base_nolib.tsv')
+    cor_custom = compare_with_db('CITAMA_ANA_5/julie_custom_nolib.tsv')
+    cor_finetune = compare_with_db('CITAMA_ANA_5/julie_finetune_nolib.tsv')
\ No newline at end of file