class Common_Dataset(Dataset):
    """Torch Dataset over a peptide DataFrame.

    Expects ``self.data`` to hold columns 'Sequence', 'Retention time',
    'Spectra', 'Charge' and (when file mode is used) 'file'.
    NOTE(review): ``__init__`` and the convert/vocab handling live outside
    this hunk and are intentionally not reproduced here.
    """

    def __getitem__(self, index):
        """Return one sample as tensors.

        Returns (seq, charge, rt, intensity) normally, plus the file id as a
        fifth tensor when ``self.file_mode`` is true.
        """
        seq = self.data['Sequence'][index]
        rt = self.data['Retention time'][index]
        intensity = self.data['Spectra'][index]
        # FIX: the patch deleted this assignment while both return statements
        # below still reference `charge` — without it every __getitem__ call
        # raises NameError. Keep the Charge read alongside the new file read.
        charge = self.data['Charge'][index]
        file = self.data['file'][index]

        if self.file_mode:
            return (torch.tensor(seq), torch.tensor(charge),
                    torch.tensor(rt).float(), torch.tensor(intensity),
                    torch.tensor(file))
        return (torch.tensor(seq), torch.tensor(charge),
                torch.tensor(rt).float(), torch.tensor(intensity))

    def set_file_mode(self, b):
        """Toggle whether __getitem__ also yields the file id tensor."""
        self.file_mode = b

    def __len__(self) -> int:
        # Number of rows in the backing DataFrame.
        return self.data.shape[0]
display=False, save=True, path='fig/custom model res/histo_length_ISA_ISA.png') +# +# df = pd.read_csv('output/out_prosit_common.csv') +# add_length(df) +# df['abs_error'] = np.abs(df['rt pred']-df['true rt']) +# histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_prosit.png') +# scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit.png', color=True) +# histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit.png') + +df = pd.read_csv('output/out_common_transfereval.csv') add_length(df) df['abs_error'] = np.abs(df['rt pred']-df['true rt']) -histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_ISA_ISA.png') -scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_ISA_ISA.png', color=True) -histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_ISA_ISA.png') - -df = pd.read_csv('output/out_prosit_common.csv') -add_length(df) -df['abs_error'] = np.abs(df['rt pred']-df['true rt']) -histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_prosit.png') -scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit.png', color=True) -histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit.png') - -df = pd.read_csv('output/out_common_transfer.csv') -add_length(df) -df['abs_error'] = np.abs(df['rt pred']-df['true rt']) -histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA.png') -scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA.png', color=True) -histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA.png') \ No newline at end of file +histo_abs_error(df, display=False, save=True, path='fig/custom model 
res/histo_prosit_ISA_eval.png') +scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA_eval.png', color=True) +histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA_eval.png') \ No newline at end of file diff --git a/main_custom.py b/main_custom.py index 16f55524b189cbe8c3f1b4bd44160c364289fc74..cd1538d15b126efd78d8a7b127d476565334ea75 100644 --- a/main_custom.py +++ b/main_custom.py @@ -278,8 +278,9 @@ def save_pred(model, data_val, forward, output_path): for param in model.parameters(): param.requires_grad = False if forward == 'both': - pred_rt, pred_int, seqs, charges, true_rt, true_int = [], [], [], [], [], [] - for seq, charge, rt, intensity in data_val: + pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], [] + data_val.data.set_file_mode(True) + for seq, charge, rt, intensity, file in data_val: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() @@ -290,12 +291,14 @@ def save_pred(model, data_val, forward, output_path): charges.extend(charge.data.cpu().tolist()) true_rt.extend(rt.data.cpu().tolist()) true_int.extend(intensity.data.cpu().tolist()) + file_list.extend([file]) data_frame['rt pred'] = pred_rt data_frame['seq'] = seqs data_frame['pred int'] = pred_int data_frame['true rt'] = true_rt data_frame['true int'] = true_int data_frame['charge'] = charges + data_frame['file'] = file_list diff --git a/msms_processing.py b/msms_processing.py index 5705fc43b573026f3e058d664e12b41bbf6e3213..b202abb2aba113030d18362ddf0b73e11ed45881 100644 --- a/msms_processing.py +++ b/msms_processing.py @@ -83,11 +83,17 @@ def mscatter(x,y, ax=None, m=None, **kw): # 17/01 23/01 24/01 if __name__ == '__main__': data_1 = pd.read_pickle('database/data_DIA_16_01_aligned.pkl') + data_1['file']= 1 data_2 = 
pd.read_pickle('database/data_DIA_17_01_aligned.pkl') + data_2['file'] = 2 data_3 = pd.read_pickle('database/data_DIA_20_01_aligned.pkl') + data_3['file'] = 3 data_4 = pd.read_pickle('database/data_DIA_23_01_aligned.pkl') + data_4['file'] = 4 data_5 = pd.read_pickle('database/data_DIA_24_01_aligned.pkl') + data_5['file'] = 5 data_6 = pd.read_pickle('database/data_DIA_30_01_aligned.pkl') + data_6['file'] = 6 data = pd.concat([data_1, data_2, data_3, data_4, data_5, data_6], ignore_index=True) num_total = len(data)