From cb65c471518d61d9d1b5cb90fa7528081eeec1fc Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 16 Jan 2025 11:00:23 +0100
Subject: [PATCH] data processing with variable mod OxM

---
 config.py               |  3 +++
 data/data_processing.py | 10 +++++-----
 data/dataset.py         | 10 +++++-----
 data/msms_processing.py | 10 +++++-----
 main.py                 |  6 +++---
 5 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/config.py b/config.py
index bd78b86..d2e5d2d 100644
--- a/config.py
+++ b/config.py
@@ -26,6 +26,9 @@ def load_args():
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
     parser.add_argument('--activation', type=str,default='relu')
+    parser.add_argument('--seq_train', type=str, default='sequence')
+    parser.add_argument('--seq_test', type=str, default='sequence')
+    parser.add_argument('--seq_val', type=str, default='sequence')
     parser.add_argument('--n_head', type=int, default=1)
     args = parser.parse_args()
 
diff --git a/data/data_processing.py b/data/data_processing.py
index 705e83f..5f104e8 100644
--- a/data/data_processing.py
+++ b/data/data_processing.py
@@ -5,9 +5,9 @@ from loess.loess_1d import loess_1d
 from constant import ALPHABET_UNMOD_REV
 
 
-def align(dataset, reference, column_dataset, column_ref):
-    dataset_unique = dataset[['sequence',column_dataset]].groupby('sequence').mean()
-    reference_unique = reference[['mod_sequence',column_ref]].groupby('mod_sequence').mean()
+def align(dataset, reference, column_dataset, column_ref, seq_data, seq_ref):
+    dataset_unique = dataset[[seq_data,column_dataset]].groupby(seq_data).mean()
+    reference_unique = reference[[seq_ref,column_ref]].groupby(seq_ref).mean()
     seq_ref = reference_unique.index
     seq_common = dataset_unique.index
     seq_ref = seq_ref.tolist()
@@ -116,12 +116,12 @@ def numerical_to_alphabetical_str(s):
 def main():
     ref = pd.read_csv('data_prosit/data.csv')
     df_ISA = pd.read_csv('data_ISA_mox/data_isa.csv')
-    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled')
+    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence')
     df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa.csv', index=False)
 
     ref = pd.read_csv('data_prosit/data_noc.csv')
     df_ISA = pd.read_csv('data_ISA_mox/data_isa_noc.csv')
-    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled')
+    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled','sequence', 'mod_sequence')
     df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa_noc.csv', index=False)
 
 
diff --git a/data/dataset.py b/data/dataset.py
index 8174384..9afa113 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -50,9 +50,9 @@ def alphabetical_to_numerical(seq):
             num.append(ALPHABET_UNMOD[seq[i + dec]])
         else:
             if seq[i + dec + 1:i + dec + 4] == 'CaC':
-                num.append(21)
+                num.append(ALPHABET_UNMOD['CaC'])
             elif seq[i + dec + 1:i + dec + 4] == 'OxM':
-                num.append(22)
+                num.append(ALPHABET_UNMOD['OxM'])
             else:
                 raise 'Modification not supported'
             dec += 4
@@ -61,7 +61,7 @@ def alphabetical_to_numerical(seq):
 
 class RT_Dataset(Dataset):
 
-    def __init__(self, size, data_source, mode, length, format='iRT_scaled'):
+    def __init__(self, size, data_source, mode, length, format='iRT_scaled', seq_col='sequence'):
         print('Data loader Initialisation')
         self.data = pd.read_csv(data_source)
 
@@ -101,9 +101,9 @@ class RT_Dataset(Dataset):
         return self.data.shape[0]
 
 
-def load_data(batch_size, data_source, length=25, mode='train', size=None):
+def load_data(batch_size, data_source, length=25, mode='train', size=None, seq_col = 'sequence'):
     print('Loading data')
-    data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length)
+    data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length, seq_col=seq_col)
     data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)
 
     return data_loader
\ No newline at end of file
diff --git a/data/msms_processing.py b/data/msms_processing.py
index 0002f7e..b575bd5 100644
--- a/data/msms_processing.py
+++ b/data/msms_processing.py
@@ -76,19 +76,19 @@ def add_split_column(data, split=(0.7,0.15,0.15)):
     return data_split
 
 def main():
-    # df_03_02 = load_data('data_ISA_mox/msms03_02.txt', 70)
+    df_03_02 = load_data('data_ISA_mox/msms_03_02.txt', 70)
     df_16_01 = load_data('data_ISA_mox/msms_16_01.txt', 70)
     df_20_01 = load_data('data_ISA_mox/msms_20_01.txt', 70)
     df_30_01 = load_data('data_ISA_mox/msms_30_01.txt', 70)
-    merged_df = pd.concat([df_20_01, df_30_01, df_16_01], ignore_index=True)
+    merged_df = pd.concat([df_20_01, df_30_01, df_16_01, df_03_02], ignore_index=True)
     final_df = add_split_column(merged_df)
     final_df.to_csv('data_ISA_mox/data_isa.csv', index=False)
     df2 = filter_cysteine(final_df)
     df2.to_csv('data_ISA_mox/data_isa_noc.csv', index=False)
 
-    final_df= pd.read_csv('data_prosit/data.csv')
-    df2 = filter_cysteine(final_df)
-    df2.to_csv('data_prosit/data_noc.csv', index=False)
+    # final_df= pd.read_csv('data_prosit/data.csv')
+    # df2 = filter_cysteine(final_df)
+    # df2.to_csv('data_prosit/data_noc.csv', index=False)
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
diff --git a/main.py b/main.py
index 316f779..5113e81 100644
--- a/main.py
+++ b/main.py
@@ -95,9 +95,9 @@ def main(args):
     print(args)
     print('Cuda : ', torch.cuda.is_available())
 
-    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train)
-    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test)
-    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val)
+    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train, seq_col=args.seq_train)
+    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test, seq_col=args.seq_test)
+    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val, seq_col=args.seq_val)
     print('\nData loaded')
 
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
-- 
GitLab