From cb65c471518d61d9d1b5cb90fa7528081eeec1fc Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Thu, 16 Jan 2025 11:00:23 +0100 Subject: [PATCH] data processing with variable mod OxM --- config.py | 3 +++ data/data_processing.py | 10 +++++----- data/dataset.py | 10 +++++----- data/msms_processing.py | 10 +++++----- main.py | 6 +++--- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/config.py b/config.py index bd78b86..d2e5d2d 100644 --- a/config.py +++ b/config.py @@ -26,6 +26,9 @@ def load_args(): parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction) parser.add_argument('--activation', type=str,default='relu') + parser.add_argument('--seq_train', type=str, default='sequence') + parser.add_argument('--seq_test', type=str, default='sequence') + parser.add_argument('--seq_val', type=str, default='sequence') parser.add_argument('--n_head', type=int, default=1) args = parser.parse_args() diff --git a/data/data_processing.py b/data/data_processing.py index 705e83f..5f104e8 100644 --- a/data/data_processing.py +++ b/data/data_processing.py @@ -5,9 +5,9 @@ from loess.loess_1d import loess_1d from constant import ALPHABET_UNMOD_REV -def align(dataset, reference, column_dataset, column_ref): - dataset_unique = dataset[['sequence',column_dataset]].groupby('sequence').mean() - reference_unique = reference[['mod_sequence',column_ref]].groupby('mod_sequence').mean() +def align(dataset, reference, column_dataset, column_ref, seq_data, seq_ref): + dataset_unique = dataset[[seq_data,column_dataset]].groupby(seq_data).mean() + reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean() seq_ref = reference_unique.index seq_common = dataset_unique.index seq_ref = seq_ref.tolist() @@ -116,12 +116,12 @@ def numerical_to_alphabetical_str(s): def main(): ref = pd.read_csv('data_prosit/data.csv') df_ISA = pd.read_csv('data_ISA_mox/data_isa.csv') - df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled') + df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence') df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa.csv', index=False) ref = pd.read_csv('data_prosit/data_noc.csv') df_ISA = pd.read_csv('data_ISA_mox/data_isa_noc.csv') - df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled') + df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence') df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa_noc.csv', index=False) diff --git a/data/dataset.py b/data/dataset.py index 8174384..9afa113 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -50,9 +50,9 @@ def alphabetical_to_numerical(seq): num.append(ALPHABET_UNMOD[seq[i + dec]]) else: if seq[i + dec + 1:i + dec + 4] == 'CaC': - num.append(21) + num.append(ALPHABET_UNMOD['CaC']) elif seq[i + dec + 1:i + dec + 4] == 'OxM': - num.append(22) + num.append(ALPHABET_UNMOD['OxM']) else: raise 'Modification not supported' dec += 4 @@ -61,7 +61,7 @@ def alphabetical_to_numerical(seq): class RT_Dataset(Dataset): - def __init__(self, size, data_source, mode, length, format='iRT_scaled'): + def __init__(self, size, data_source, mode, length, format='iRT_scaled', seq_col='sequence'): print('Data loader Initialisation') self.data = pd.read_csv(data_source) @@ -101,9 +101,9 @@ class RT_Dataset(Dataset): return self.data.shape[0] -def load_data(batch_size, data_source, length=25, mode='train', size=None): +def load_data(batch_size, data_source, length=25, mode='train', size=None, seq_col = 'sequence'): print('Loading data') - data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length) + data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length, seq_col=seq_col) data_loader = DataLoader(data, batch_size=batch_size, shuffle=True) return data_loader \ No newline at end of file diff --git a/data/msms_processing.py b/data/msms_processing.py index 0002f7e..b575bd5 100644 --- a/data/msms_processing.py +++ b/data/msms_processing.py @@ -76,19 +76,19 @@ def add_split_column(data, split=(0.7,0.15,0.15)): return data_split def main(): - # df_03_02 = load_data('data_ISA_mox/msms03_02.txt', 70) + df_03_02 = load_data('data_ISA_mox/msms_03_02.txt', 70) df_16_01 = load_data('data_ISA_mox/msms_16_01.txt', 70) df_20_01 = load_data('data_ISA_mox/msms_20_01.txt', 70) df_30_01 = load_data('data_ISA_mox/msms_30_01.txt', 70) - merged_df = pd.concat([df_20_01, df_30_01, df_16_01], ignore_index=True) + merged_df = pd.concat([df_20_01, df_30_01, df_16_01, df_03_02], ignore_index=True) final_df = add_split_column(merged_df) final_df.to_csv('data_ISA_mox/data_isa.csv', index=False) df2 = filter_cysteine(final_df) df2.to_csv('data_ISA_mox/data_isa_noc.csv', index=False) - final_df= pd.read_csv('data_prosit/data.csv') - df2 = filter_cysteine(final_df) - df2.to_csv('data_prosit/data_noc.csv', index=False) + # final_df= pd.read_csv('data_prosit/data.csv') + # df2 = filter_cysteine(final_df) + # df2.to_csv('data_prosit/data_noc.csv', index=False) if __name__ == '__main__': main() \ No newline at end of file diff --git a/main.py b/main.py index 316f779..5113e81 100644 --- a/main.py +++ b/main.py @@ -95,9 +95,9 @@ def main(args): print(args) print('Cuda : ', torch.cuda.is_available()) - data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train) - data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test) - data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val) + data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train, seq_col=args.seq_train) + data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test, seq_col=args.seq_test) + data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val, seq_col=args.seq_val) print('\nData loaded') model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, -- GitLab