diff --git a/config.py b/config.py
index bd78b863f50b902d27412b2d1ceba02be9a02bfd..d2e5d2d71f1c0715b27b37cfc00187de72677fee 100644
--- a/config.py
+++ b/config.py
@@ -26,6 +26,9 @@ def load_args():
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
     parser.add_argument('--activation', type=str,default='relu')
+    parser.add_argument('--seq_train', type=str, default='sequence')
+    parser.add_argument('--seq_test', type=str, default='sequence')
+    parser.add_argument('--seq_val', type=str, default='sequence')
     parser.add_argument('--n_head', type=int, default=1)
     args = parser.parse_args()
 
diff --git a/data/data_processing.py b/data/data_processing.py
index 705e83fdd2f13b712caaf9d87ed38078e61ea7c0..5f104e8210d28a015f0d9e31a8609fb6760502fa 100644
--- a/data/data_processing.py
+++ b/data/data_processing.py
@@ -5,9 +5,9 @@ from loess.loess_1d import loess_1d
 from constant import ALPHABET_UNMOD_REV
 
 
-def align(dataset, reference, column_dataset, column_ref):
-    dataset_unique = dataset[['sequence',column_dataset]].groupby('sequence').mean()
-    reference_unique = reference[['mod_sequence',column_ref]].groupby('mod_sequence').mean()
+def align(dataset, reference, column_dataset, column_ref, seq_col_data, seq_col_ref):
+    dataset_unique = dataset[[seq_col_data, column_dataset]].groupby(seq_col_data).mean()
+    reference_unique = reference[[seq_col_ref, column_ref]].groupby(seq_col_ref).mean()
     seq_ref = reference_unique.index
     seq_common = dataset_unique.index
     seq_ref = seq_ref.tolist()
@@ -116,12 +116,12 @@ def numerical_to_alphabetical_str(s):
 def main():
     ref = pd.read_csv('data_prosit/data.csv')
     df_ISA = pd.read_csv('data_ISA_mox/data_isa.csv')
-    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled')
+    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled', 'sequence', 'mod_sequence')
     df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa.csv', index=False)
 
     ref = pd.read_csv('data_prosit/data_noc.csv')
     df_ISA = pd.read_csv('data_ISA_mox/data_isa_noc.csv')
-    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled')
+    df_ISA_aligned = align(df_ISA, ref, 'irt_scaled', 'irt_scaled', 'sequence', 'mod_sequence')
     df_ISA_aligned.to_csv('data_ISA_mox/data_aligned_isa_noc.csv', index=False)
 
 
diff --git a/data/dataset.py b/data/dataset.py
index 8174384a3545216880fe95de5d3af4b9a823f7eb..9afa113585f6054ff0b49e148adf7f356198a083 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -50,9 +50,9 @@ def alphabetical_to_numerical(seq):
             num.append(ALPHABET_UNMOD[seq[i + dec]])
         else:
             if seq[i + dec + 1:i + dec + 4] == 'CaC':
-                num.append(21)
+                num.append(ALPHABET_UNMOD['CaC'])
             elif seq[i + dec + 1:i + dec + 4] == 'OxM':
-                num.append(22)
+                num.append(ALPHABET_UNMOD['OxM'])
             else:
                 raise 'Modification not supported'
             dec += 4
@@ -61,7 +61,7 @@ def alphabetical_to_numerical(seq):
 
 class RT_Dataset(Dataset):
 
-    def __init__(self, size, data_source, mode, length, format='iRT_scaled'):
+    def __init__(self, size, data_source, mode, length, format='iRT_scaled', seq_col='sequence'):  # NOTE(review): seq_col is accepted but no hunk in this patch uses it inside __init__ — confirm the body actually reads it to select the sequence column
         print('Data loader Initialisation')
         self.data = pd.read_csv(data_source)
 
@@ -101,9 +101,9 @@ class RT_Dataset(Dataset):
         return self.data.shape[0]
 
 
-def load_data(batch_size, data_source, length=25, mode='train', size=None):
+def load_data(batch_size, data_source, length=25, mode='train', size=None, seq_col='sequence'):
     print('Loading data')
-    data = RT_Dataset(size = None, data_source=data_source, mode=mode, length=length)
+    data = RT_Dataset(size=size, data_source=data_source, mode=mode, length=length, seq_col=seq_col)
     data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)
 
     return data_loader
\ No newline at end of file
diff --git a/data/msms_processing.py b/data/msms_processing.py
index 0002f7e4413d94c8de96ffdc00652c93898e097d..b575bd565e4eb3bc26f1ecb1b256c0c5dbd01a02 100644
--- a/data/msms_processing.py
+++ b/data/msms_processing.py
@@ -76,19 +76,19 @@ def add_split_column(data, split=(0.7,0.15,0.15)):
     return data_split
 
 def main():
-    # df_03_02 = load_data('data_ISA_mox/msms03_02.txt', 70)
+    df_03_02 = load_data('data_ISA_mox/msms_03_02.txt', 70)
     df_16_01 = load_data('data_ISA_mox/msms_16_01.txt', 70)
     df_20_01 = load_data('data_ISA_mox/msms_20_01.txt', 70)
     df_30_01 = load_data('data_ISA_mox/msms_30_01.txt', 70)
-    merged_df = pd.concat([df_20_01, df_30_01, df_16_01], ignore_index=True)
+    merged_df = pd.concat([df_20_01, df_30_01, df_16_01, df_03_02], ignore_index=True)
     final_df = add_split_column(merged_df)
     final_df.to_csv('data_ISA_mox/data_isa.csv', index=False)
     df2 = filter_cysteine(final_df)
     df2.to_csv('data_ISA_mox/data_isa_noc.csv', index=False)
 
-    final_df= pd.read_csv('data_prosit/data.csv')
-    df2 = filter_cysteine(final_df)
-    df2.to_csv('data_prosit/data_noc.csv', index=False)
+    # final_df= pd.read_csv('data_prosit/data.csv')
+    # df2 = filter_cysteine(final_df)
+    # df2.to_csv('data_prosit/data_noc.csv', index=False)
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
diff --git a/main.py b/main.py
index 316f779db81d428f444bd40e42e8ba06c7b7541a..5113e81b9b78b79636f103b745027e3353535f69 100644
--- a/main.py
+++ b/main.py
@@ -95,9 +95,9 @@ def main(args):
     print(args)
     print('Cuda : ', torch.cuda.is_available())
 
-    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train)
-    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test)
-    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val)
+    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train, seq_col=args.seq_train)
+    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test, seq_col=args.seq_test)
+    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val, seq_col=args.seq_val)
     print('\nData loaded')
 
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,