Skip to content
Snippets Groups Projects
Commit 60f04719 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix

parent 10b5e5f1
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import pyarrow.parquet as pq ...@@ -5,7 +5,7 @@ import pyarrow.parquet as pq
import pyarrow as pa import pyarrow as pa
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# from loess.loess_1d import loess_1d from loess.loess_1d import loess_1d
from model.model import ModelTransformer from model.model import ModelTransformer
from config import load_args from config import load_args
...@@ -96,98 +96,94 @@ def predict(data_pred, model, output_path): ...@@ -96,98 +96,94 @@ def predict(data_pred, model, output_path):
if __name__ =='__main__': if __name__ =='__main__':
# df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet') df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet')
# df = extract_sequence(df)
# df.to_csv('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv')
# plt.hist(df['RT'])
# plt.savefig('test.png')
#
# df_2 = pd.read_csv('data_prosit/data.csv')
#
# plt.clf()
# plt.hist(df_2['irt'])
# plt.savefig('test2.png')
#
# df_2 = extract_sequence(df).reset_index(drop=True)
#
# pred = pd.read_csv('../output/out_lib_CITBASE_try_contaminant.csv')
#
# pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
#
# pred['Modified.Sequence']=pred['seq']
#
# result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
#
#
#
# #alignement
#
# ref = pd.read_csv('data_prosit/data_noc.csv')
# df_ISA = pd.read_csv('data_ISA/data_aligned_isa_noc.csv')
#
# dataset, reference, column_dataset, column_ref, seq_data, seq_ref = df_ISA, ref, 'irt_scaled', 'irt', 'sequence','sequence',
#
# dataset_ref=dataset[dataset['state']=='train']
# dataset_unique = dataset_ref[[seq_data,column_dataset]].groupby(seq_data).mean()
# print('unique',len(dataset_unique))
# reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
# seq_ref = reference_unique.index
# seq_common = dataset_unique.index
# seq_ref = seq_ref.tolist()
# seq_common = seq_common.tolist()
#
# seq_ref = [tuple(l) for l in seq_ref]
# seq_common = [tuple(l) for l in seq_common]
#
# ind_dict_ref = dict((k, i) for i, k in enumerate(seq_ref))
# inter = set(ind_dict_ref).intersection(seq_common)
# print(len(inter))
#
# ind_dict_ref = [ind_dict_ref[x] for x in inter]
#
# indices_common = dict((k, i) for i, k in enumerate(seq_common))
# indices_common = [indices_common[x] for x in inter]
#
#
# rt_ref = reference_unique[column_ref][ind_dict_ref].reset_index()
# rt_data = dataset_unique[column_dataset][indices_common].reset_index()
#
# plt.scatter(rt_data[column_dataset].tolist(),rt_ref[column_ref].tolist(),s=0.1)
# plt.savefig('test.png')
#
# #présence de NAN qui casse le réalignement (solution temporaire : remplacer par 0.
# result['rt pred']=result['rt pred'].fillna(value=0)
# xout, yout, wout = loess_1d(np.array(rt_data[column_dataset].tolist()), np.array(rt_ref[column_ref].tolist()),
# xnew=result['rt pred'],
# degree=1,
# npoints=None, rotate=False, sigy=None)
#
#
# #writing results
#
# result['RT'] = yout
#
# result = result.drop('rt pred', axis=1)
#
# table = pa.Table.from_pandas(result)
#
# pq.write_table(table, 'spectral_lib/first_lib_contaminant_prosit_aligned.parquet')
#
df_2 = pd.read_csv('data_prosit/data.csv')
args = load_args() plt.clf()
plt.hist(df_2['irt'])
plt.savefig('test2.png')
model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, df_2 = extract_sequence(df).reset_index(drop=True)
n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)
if torch.cuda.is_available(): pred = pd.read_csv('../output/out_transfer_prosit_isa_1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv')
model = model.cuda()
model.load_state_dict(torch.load(args.model_weigh, weights_only=True)) pred['seq']=pred['seq'].map(numerical_to_alphabetical_str)
data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test, pred['Modified.Sequence']=pred['seq']
seq_col=args.seq_test)
predict(data_test, model, args.output) result = pd.merge(df,pred[['Modified.Sequence','rt pred']],on='Modified.Sequence',how='left')
#alignement
ref = pd.read_csv('data_prosit/data_noc.csv')
df_ISA = pd.read_csv('data_ISA/data_aligned_isa_noc.csv')
dataset, reference, column_dataset, column_ref, seq_data, seq_ref = df_ISA, ref, 'irt_scaled', 'irt', 'sequence','sequence',
dataset_ref=dataset[dataset['state']=='train']
dataset_unique = dataset_ref[[seq_data,column_dataset]].groupby(seq_data).mean()
print('unique',len(dataset_unique))
reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
seq_ref = reference_unique.index
seq_common = dataset_unique.index
seq_ref = seq_ref.tolist()
seq_common = seq_common.tolist()
seq_ref = [tuple(l) for l in seq_ref]
seq_common = [tuple(l) for l in seq_common]
ind_dict_ref = dict((k, i) for i, k in enumerate(seq_ref))
inter = set(ind_dict_ref).intersection(seq_common)
print(len(inter))
ind_dict_ref = [ind_dict_ref[x] for x in inter]
indices_common = dict((k, i) for i, k in enumerate(seq_common))
indices_common = [indices_common[x] for x in inter]
rt_ref = reference_unique[column_ref][ind_dict_ref].reset_index()
rt_data = dataset_unique[column_dataset][indices_common].reset_index()
plt.scatter(rt_data[column_dataset].tolist(),rt_ref[column_ref].tolist(),s=0.1)
plt.savefig('test.png')
#présence de NAN qui casse le réalignement (solution temporaire : remplacer par 0.
result['rt pred']=result['rt pred'].fillna(value=0)
xout, yout, wout = loess_1d(np.array(rt_data[column_dataset].tolist()), np.array(rt_ref[column_ref].tolist()),
xnew=result['rt pred'],
degree=1,
npoints=None, rotate=False, sigy=None)
#writing results
result['RT'] = yout
result = result.drop('rt pred', axis=1)
table = pa.Table.from_pandas(result)
pq.write_table(table, 'spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang_finetune_aligned.parquet')
# args = load_args()
#
# model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
# n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
# decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
# embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)
#
# if torch.cuda.is_available():
# model = model.cuda()
#
# model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
#
# data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,
# seq_col=args.seq_test)
#
# predict(data_test, model, args.output)
...@@ -41,7 +41,19 @@ def compare_error(path_1,path_2): ...@@ -41,7 +41,19 @@ def compare_error(path_1,path_2):
plt.savefig('error2.png') plt.savefig('error2.png')
return error_1,error_2 return error_1,error_2
def compare_with_db(path):
df = pd.read_csv(path, sep='\t', encoding='latin-1')
df_ref = pd.read_excel('250205_All_Peptides_panel_ID_+_RES.xlsx',names=['peptide','fonction'])
df2=df[df['Stripped.Sequence'].isin(df_ref['peptide'].to_list())]
corespondance = pd.merge(df2,df_ref,left_on='Stripped.Sequence',right_on='peptide',how="left")
return corespondance
if __name__ == '__main__': if __name__ == '__main__':
# compare_id('CITCRO_ANA_3/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv', 'CITCRO_ANA_3/report_finetune.tsv','CITCRO_ANA_3') # compare_id('CITAMA_ANA_5/julie_custom_nolib.tsv', 'CITAMA_ANA_5/julie_base_nolib.tsv', 'CITAMA_ANA_5/julie_finetune_nolib.tsv','CITAMA_ANA_5_julie_no_lib')
e1,e2 = compare_error('CITCRO_ANA_3/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv') # e1,e2 = compare_error('CITAMA_ANA_5/report_custom.tsv', 'CITCRO_ANA_3/report_first_lib.tsv')
\ No newline at end of file
cor_base = compare_with_db('CITAMA_ANA_5/julie_base_nolib.tsv')
cor_custom = compare_with_db('CITAMA_ANA_5/julie_custom_nolib.tsv')
cor_finetune = compare_with_db('CITAMA_ANA_5/julie_finetune_nolib.tsv')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment