diff --git a/config.py b/config.py index bd78b863f50b902d27412b2d1ceba02be9a02bfd..d2e5d2d71f1c0715b27b37cfc00187de72677fee 100644 --- a/config.py +++ b/config.py @@ -26,6 +26,9 @@ def load_args(): parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction) parser.add_argument('--activation', type=str,default='relu') + parser.add_argument('--seq_train', type=str, default='sequence') + parser.add_argument('--seq_test', type=str, default='sequence') + parser.add_argument('--seq_val', type=str, default='sequence') parser.add_argument('--n_head', type=int, default=1) args = parser.parse_args() diff --git a/data/data_processing.py b/data/data_processing.py index 705e83fdd2f13b712caaf9d87ed38078e61ea7c0..5f104e8210d28a015f0d9e31a8609fb6760502fa 100644 --- a/data/data_processing.py +++ b/data/data_processing.py @@ -5,9 +5,9 @@ from loess.loess_1d import loess_1d from constant import ALPHABET_UNMOD_REV -def align(dataset, reference, column_dataset, column_ref): - dataset_unique = dataset[['sequence',column_dataset]].groupby('sequence').mean() - reference_unique = reference[['mod_sequence',column_ref]].groupby('mod_sequence').mean() +def align(dataset, reference, column_dataset, column_ref, seq_data, seq_ref): + dataset_unique = dataset[[seq_data,column_dataset]].groupby(seq_data).mean() + reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean() seq_ref = reference_unique.index seq_common = dataset_unique.index seq_ref = seq_ref.tolist() @@ -116,12 +116,12 @@ def numerical_to_alphabetical_str(s): def main(): ref = pd.read_csv('data_prosit/data.csv') df_ISA = pd.read_csv('data_ISA_mox/data_isa.csv') - df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled') + df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence') df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa.csv', index=False) ref = pd.read_csv('data_prosit/data_noc.csv') df_ISA = pd.read_csv('data_ISA_mox/data_isa_noc.csv') - df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled') + df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence') df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa_noc.csv', index=False) diff --git a/data/dataset.py b/data/dataset.py index 8174384a3545216880fe95de5d3af4b9a823f7eb..9afa113585f6054ff0b49e148adf7f356198a083 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -50,9 +50,9 @@ def alphabetical_to_numerical(seq): num.append(ALPHABET_UNMOD[seq[i + dec]]) else: if seq[i + dec + 1:i + dec + 4] == 'CaC': - num.append(21) + num.append(ALPHABET_UNMOD['CaC']) elif seq[i + dec + 1:i + dec + 4] == 'OxM': - num.append(22) + num.append(ALPHABET_UNMOD['OxM']) else: raise 'Modification not supported' dec += 4 @@ -61,7 +61,7 @@ def alphabetical_to_numerical(seq): class RT_Dataset(Dataset): - def __init__(self, size, data_source, mode, length, format='iRT_scaled'): + def __init__(self, size, data_source, mode, length, format='iRT_scaled', seq_col='sequence'): print('Data loader Initialisation') self.data = pd.read_csv(data_source) @@ -101,9 +101,9 @@ class RT_Dataset(Dataset): return self.data.shape[0] -def load_data(batch_size, data_source, length=25, mode='train', size=None): +def load_data(batch_size, data_source, length=25, mode='train', size=None, seq_col = 'sequence'): print('Loading data') - data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length) + data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length, seq_col=seq_col) data_loader = DataLoader(data, batch_size=batch_size, shuffle=True) return data_loader \ No newline at end of file diff --git a/data/msms_processing.py b/data/msms_processing.py index 0002f7e4413d94c8de96ffdc00652c93898e097d..b575bd565e4eb3bc26f1ecb1b256c0c5dbd01a02 100644 --- a/data/msms_processing.py +++ b/data/msms_processing.py @@ -76,19 +76,19 @@ def add_split_column(data, split=(0.7,0.15,0.15)): return data_split def main(): - # df_03_02 = load_data('data_ISA_mox/msms03_02.txt', 70) + df_03_02 = load_data('data_ISA_mox/msms_03_02.txt', 70) df_16_01 = load_data('data_ISA_mox/msms_16_01.txt', 70) df_20_01 = load_data('data_ISA_mox/msms_20_01.txt', 70) df_30_01 = load_data('data_ISA_mox/msms_30_01.txt', 70) - merged_df = pd.concat([df_20_01, df_30_01, df_16_01], ignore_index=True) + merged_df = pd.concat([df_20_01, df_30_01, df_16_01, df_03_02], ignore_index=True) final_df = add_split_column(merged_df) final_df.to_csv('data_ISA_mox/data_isa.csv', index=False) df2 = filter_cysteine(final_df) df2.to_csv('data_ISA_mox/data_isa_noc.csv', index=False) - final_df= pd.read_csv('data_prosit/data.csv') - df2 = filter_cysteine(final_df) - df2.to_csv('data_prosit/data_noc.csv', index=False) + # final_df= pd.read_csv('data_prosit/data.csv') + # df2 = filter_cysteine(final_df) + # df2.to_csv('data_prosit/data_noc.csv', index=False) if __name__ == '__main__': main() \ No newline at end of file diff --git a/main.py b/main.py index 316f779db81d428f444bd40e42e8ba06c7b7541a..5113e81b9b78b79636f103b745027e3353535f69 100644 --- a/main.py +++ b/main.py @@ -95,9 +95,9 @@ def main(args): print(args) print('Cuda : ', torch.cuda.is_available()) - data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train) - data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test) - data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val) + data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train, seq_col=args.seq_train) + data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test, seq_col=args.seq_test) + data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val, seq_col=args.seq_val) print('\nData loaded') model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,