diff --git a/ELN/ELN.ods b/ELN/ELN.ods
new file mode 100644
index 0000000000000000000000000000000000000000..00f5e025a3a7f9d23b7bd23fd2ea9bed275ff2f0
Binary files /dev/null and b/ELN/ELN.ods differ
diff --git a/ProteinBert/config.json b/ProteinBert/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..2cf0cbf700474a7f5a2e3adaf17ce63d37dd539e
--- /dev/null
+++ b/ProteinBert/config.json
@@ -0,0 +1,23 @@
+{
+  "attention_probs_dropout_prob": 0.1,
+  "base_model": "transformer",
+  "finetuning_task": null,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "input_size": 768,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 8192,
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "num_labels": 2,
+  "output_attentions": false,
+  "output_hidden_states": false,
+  "output_size": 768,
+  "pruned_heads": {},
+  "torchscript": false,
+  "type_vocab_size": 1,
+  "vocab_size": 30
+}
diff --git a/ProteinBert/pytorch_model.bin b/ProteinBert/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..108f54af3c7292d34d412d94a59a1579177a0a14
Binary files /dev/null and b/ProteinBert/pytorch_model.bin differ
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..139cefffe52066288a7e925af9300f1d050b3dce
--- /dev/null
+++ b/README.md
@@ -0,0 +1,6 @@
+# LC-MS-RT-prediction
+
+
+First test of proteomics encoding through RT and peak-intensity prediction as a pretext task.
+
+This is part of finding a common latent space for FASTA file encoding and spectrogram data.
diff --git a/common_dataset.py b/common_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e468458a2e17c03a65e851ece6db955df3f9bb6
--- /dev/null
+++ b/common_dataset.py
@@ -0,0 +1,159 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import pandas as pd
+
# Token table for unmodified peptides: 0 is the padding token "_",
# 1-20 the standard amino acids, 21/22 the two supported modifications
# (carbamidomethyl-Cys 'CaC', oxidized Met 'OxM').
ALPHABET_UNMOD = {
    "_": 0,
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
    "CaC": 21,
    "OxM": 22
}

# IUPAC-style vocabulary: special tokens first, then the extended
# amino-acid alphabet (incl. ambiguity codes B, O, U, X, Z).
IUPAC_VOCAB = {
    "_": 0,
    "<mask>": 1,
    "<cls>": 2,
    "<sep>": 3,
    "<unk>": 4,
    "A": 5,
    "B": 6,
    "C": 7,
    "D": 8,
    "E": 9,
    "F": 10,
    "G": 11,
    "H": 12,
    "I": 13,
    "K": 14,
    "L": 15,
    "M": 16,
    "N": 17,
    "O": 18,
    "P": 19,
    "Q": 20,
    "R": 21,
    "S": 22,
    "T": 23,
    "U": 24,
    "V": 25,
    "W": 26,
    "X": 27,
    "Y": 28,
    "Z": 29}

# Inverse mapping (index -> token) used to decode numerical sequences.
ALPHABET_UNMOD_REV = {v: k for k, v in ALPHABET_UNMOD.items()}
+
+
def padding(dataframe, columns, length):
    """Pad the strings in `dataframe[columns]` with '_' up to `length`.

    Modification markers ('-CaC-', '-OxM-') take 5 characters but encode a
    single residue, so each pair of '-' grants 2 extra characters of slack.
    Rows whose sequence exceeds the allowed maximum are dropped.

    Bug fix: the original called `dataframe.drop(i)` without keeping the
    result or passing inplace=True, so no row was ever removed.

    The dataframe is mutated in place; nothing is returned.
    """
    def pad(x):
        return x + (length - len(x) + 2 * x.count('-')) * '_'

    # Drop over-long sequences *before* padding, then re-index so later
    # positional access (dataframe[columns][i]) stays dense.
    keep = dataframe[columns].map(lambda x: len(x) <= length + 2 * x.count('-'))
    dataframe.drop(dataframe.index[~keep], inplace=True)
    dataframe.reset_index(drop=True, inplace=True)
    dataframe[columns] = dataframe[columns].map(pad)
+
+
def alphabetical_to_numerical(seq, vocab):
    """Convert a (possibly padded) sequence string to a numerical np.array.

    vocab: 'unmod' uses ALPHABET_UNMOD; anything else uses IUPAC_VOCAB.
    Modifications are written '-CaC-' / '-OxM-' (5 characters for a single
    residue) and always encode as 21 / 22 — NOTE(review): those ids collide
    with 'R'/'S' in IUPAC_VOCAB; confirm this is intended.

    Raises ValueError for an unknown modification. (Bug fix: the original
    executed `raise 'Modification not supported'`, which raises a TypeError
    in Python 3 because strings are not exceptions. The two vocab branches
    were also duplicated except for the lookup table.)
    """
    table = ALPHABET_UNMOD if vocab == 'unmod' else IUPAC_VOCAB
    num = []
    dec = 0  # extra characters consumed by modification markers so far
    for i in range(len(seq) - 2 * seq.count('-')):
        if seq[i + dec] != '-':
            num.append(table[seq[i + dec]])
        else:
            mod = seq[i + dec + 1:i + dec + 4]
            if mod == 'CaC':
                num.append(21)
            elif mod == 'OxM':
                num.append(22)
            else:
                raise ValueError('Modification not supported: ' + mod)
            dec += 4  # skip the rest of the '-XXX-' marker
    return np.array(num)
+
def numerical_to_alphabetical(arr):
    """Decode a numerical sequence back to its string form via ALPHABET_UNMOD_REV."""
    return ''.join(ALPHABET_UNMOD_REV[token] for token in arr)
+
def zero_to_minus(arr):
    """Set near-zero entries (<= 1e-5) of `arr` to -1 in place and return it."""
    near_zero = arr <= 0.00001
    arr[near_zero] = -1.
    return arr
+
+
class Common_Dataset(Dataset):
    """Dataset over a dataframe with 'Sequence', 'Charge', 'Retention time'
    and 'Spectra' columns, yielding (seq, charge, rt, intensity) tensors.

    pad: pad/drop sequences to `length` via padding().
    convert: numerically encode sequences and map near-zero intensities to -1.
    """

    def __init__(self, dataframe, length, pad=True, convert=True, vocab='unmod'):
        print('Data loader Initialisation')
        self.data = dataframe.reset_index()
        if pad:
            print('Padding')
            padding(self.data, 'Sequence', length)

        if convert:
            print('Converting')
            self.data['Sequence'] = self.data['Sequence'].map(
                lambda s: alphabetical_to_numerical(s, vocab))
            self.data['Spectra'] = self.data['Spectra'].map(zero_to_minus)

    def __getitem__(self, index: int):
        row = self.data
        sequence = torch.tensor(row['Sequence'][index])
        charge = torch.tensor(row['Charge'][index])
        retention = torch.tensor(row['Retention time'][index]).float()
        intensity = torch.tensor(row['Spectra'][index])
        return sequence, charge, retention, intensity

    def __len__(self) -> int:
        return len(self.data)
+
+
def load_data(path_train, path_val, path_test, batch_size, length, pad=False, convert=False, vocab='unmod'):
    """Build shuffled train/val/test DataLoaders from three pickled dataframes.

    Bug fix: the original wrapped `path_val` in the *test* dataset and
    `path_test` in the *validation* dataset, so the returned val_loader and
    test_loader were swapped with respect to the file paths.
    """
    print('Loading data')
    train = Common_Dataset(pd.read_pickle(path_train), length, pad, convert, vocab)
    val = Common_Dataset(pd.read_pickle(path_val), length, pad, convert, vocab)
    test = Common_Dataset(pd.read_pickle(path_test), length, pad, convert, vocab)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)

    return train_loader, val_loader, test_loader
diff --git a/config.py b/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e41c60fe85da5db0c8e71eb880635a34c6447ff7
--- /dev/null
+++ b/config.py
@@ -0,0 +1,23 @@
+import argparse
+
+
def load_args():
    """Parse command-line options for the RT-prediction training script.

    Returns an argparse.Namespace with training hyper-parameters, model
    selection, dataset paths and the MLP layer sizes.
    """
    parser = argparse.ArgumentParser()

    # (name, type, default) — declared in a table to keep the options compact
    specs = [
        ('epochs', int, 100),
        ('save_inter', int, 100),
        ('eval_inter', int, 1),
        ('lr', float, 0.001),
        ('batch_size', int, 1024),
        ('n_test', int, None),
        ('n_train', int, None),
        ('n_head', int, 1),
        ('model', str, 'RT_multi_sum'),
        ('wandb', str, None),
        ('coef_pretext', float, 1.),
        ('dataset_train', str, 'database/data.csv'),
        ('dataset_test', str, 'database/data.csv'),
    ]
    for name, type_, default in specs:
        parser.add_argument('--' + name, type=type_, default=default)
    parser.add_argument('--layers_sizes', nargs='+', type=int, default=[256, 512, 512])

    return parser.parse_args()
diff --git a/config_common.py b/config_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf9ab098f7c5f93fb7060c025066f5dc675c1a3
--- /dev/null
+++ b/config_common.py
@@ -0,0 +1,30 @@
+import argparse
+
+
def load_args():
    """Parse command-line options for the transformer encoder/decoder script.

    Returns an argparse.Namespace with training hyper-parameters, model
    dimensions, dataset paths and architecture flags.
    """
    parser = argparse.ArgumentParser()

    # (name, type, default) — declared in a table to keep the options compact
    specs = [
        ('epochs', int, 100),
        ('save_inter', int, 100),
        ('eval_inter', int, 1),
        ('lr', float, 0.001),
        ('batch_size', int, 2048),
        ('n_head', int, 1),
        ('embedding_dim', int, 16),
        ('encoder_ff', int, 2048),
        ('decoder_rt_ff', int, 2048),
        ('decoder_int_ff', int, 512),
        ('encoder_num_layer', int, 2),
        ('decoder_rt_num_layer', int, 1),
        ('decoder_int_num_layer', int, 1),
        ('drop_rate', float, 0.035),
        ('wandb', str, None),
        ('forward', str, 'both'),
        ('dataset_train', str, 'database/data_DIA_ISA_55_train.pkl'),
        ('dataset_val', str, 'database/data_DIA_ISA_55_test.pkl'),
        ('dataset_test', str, 'database/data_DIA_ISA_55_test.pkl'),
    ]
    for name, type_, default in specs:
        parser.add_argument('--' + name, type=type_, default=default)
    # --norm_first / --no-norm_first flag pair (argparse, Python 3.9+); default None
    parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
    parser.add_argument('--activation', type=str, default='relu')

    return parser.parse_args()
diff --git a/data_exploration.py b/data_exploration.py
new file mode 100644
index 0000000000000000000000000000000000000000..019c47127d7b437253c4af8ecf9726a9bc887f62
--- /dev/null
+++ b/data_exploration.py
@@ -0,0 +1,287 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+import pandas as pd
+
matplotlib.use('agg')  # headless backend: figures can only be saved, not shown
length = 30  # maximum peptide length considered throughout this script
+
+
# 0-based amino-acid index table used only for the histograms in this script.
# NOTE(review): differs from dataloader.py's table, which is 1-based with 0
# reserved for padding — do not mix the two encodings.
ALPHABET_UNMOD = {
        "A": 0,
        "C": 1,
        "D": 2,
        "E": 3,
        "F": 4,
        "G": 5,
        "H": 6,
        "I": 7,
        "K": 8,
        "L": 9,
        "M": 10,
        "N": 11,
        "P": 12,
        "Q": 13,
        "R": 14,
        "S": 15,
        "T": 16,
        "V": 17,
        "W": 18,
        "Y": 19,
        "CaC": 20,
        "OxM": 21
    }
def separe_by_length(X, Y, base_name='out'):
    """Split (X, Y) into per-sequence-length buckets and save each as .npy files.

    NOTE(review): `list(X[i]).index(0)` searches for a numeric 0 terminator
    while `X[i].count('-')` assumes a string — presumably X holds padded
    sequences, but only one of the two interpretations can hold; confirm the
    element type before relying on the bucketing.
    """
    max = 30  # NOTE(review): shadows the builtin `max`
    datasets = [[[], []] for i in range(max)]
    for i in range(X.shape[0]):
        try:
            datasets[list(X[i]).index(0)-X[i].count('-')*2][0].append(X[i])
            datasets[list(X[i]).index(0)-X[i].count('-')*2][1].append(Y[i])
        except:
            # no padding token found (or over-long entry): last bucket
            datasets[max - 1][0].append(X[i][:])
            datasets[max - 1][1].append(Y[i])

    for length in range(max):
        print('Cutting ', length)
        if datasets[length][0]:
            print('data/X_' + base_name + '_' + str(length) + '.npy')
            X_cut = np.array(datasets[length][0])
            Y_cut = np.array(datasets[length][1])
            np.save('data/RT/X_' + base_name + '_' + str(length) + '.npy', X_cut)
            np.save('data/RT/Y_' + base_name + '_' + str(length) + '.npy', Y_cut)
+
+
def dist_long(X, plot=False, save=False, f_name='out.png'):
    """Distribution of sequence lengths as percentages of len(X).

    '-CaC-'/'-OxM-' markers contribute 2 non-residue characters per '-' pair
    and are excluded from the length. Lengths above 30 (or entries the
    length computation fails on) accumulate in the last bin. Optionally
    shows/saves a step histogram.

    Fixes: the local `max` no longer shadows the builtin, and the bare
    `except:` is narrowed to Exception so KeyboardInterrupt/SystemExit are
    not swallowed.
    """
    n_bins = 31
    dist = np.zeros(n_bins)
    for seq in X:
        try:
            dist[len(list(seq)) - seq.count('-') * 2] += 1
        except Exception:
            dist[-1] += 1  # out-of-range length or unexpected element type

    if plot or save:
        plt.stairs(dist, range(n_bins + 1), fill=True)
        if plot:
            plt.show()
        if save:
            plt.savefig(f_name)
        plt.clf()
        plt.close()
    return 100 * dist / X.shape[0]
+
+
def feq_aa(X, plot=False, save=False, f_name='out.png'):
    """Per-token frequency (%) over all sequences in X.

    Decodes '-CaC-' / '-OxM-' five-character modification markers as single
    residues. Returns {token: percentage}; optionally shows/saves a bar chart.

    NOTE(review): a '-' followed by an unknown modification advances neither
    `dec` nor any counter, which would desynchronise the scan — presumably
    the input only ever contains CaC/OxM; confirm upstream validation.
    """
    freq = np.zeros(22)
    for seq in X:
        dec = 0  # extra characters consumed by modification markers so far
        for i in range(len(seq)-2*seq.count('-')):
            if seq[dec+i] != '-':
                freq[ALPHABET_UNMOD[seq[dec+i]]] += 1
            elif seq[i+dec+1:i+dec+4] == 'CaC':
                dec += 4
                freq[ALPHABET_UNMOD['CaC']] += 1
            elif seq[i+dec+1:i+dec+4] == 'OxM':
                dec += 4
                freq[ALPHABET_UNMOD['OxM']] += 1

    freq = 100 * freq / freq.sum()

    # frequency per token, keyed by the token string itself
    dict_freq = ALPHABET_UNMOD.copy()
    for aa in list(ALPHABET_UNMOD.keys()):
        dict_freq[aa] = freq[ALPHABET_UNMOD[aa]]

    if plot or save:
        plt.bar(list(ALPHABET_UNMOD.keys()), freq, label=list(ALPHABET_UNMOD.keys()))
        if plot:
            plt.show()
        if save:
            plt.savefig(f_name)
        plt.clf()
        plt.close()
    return dict_freq
+
+
def intersection(A, B):
    """Return (count, values) of the sorted unique elements common to A and B."""
    common = np.intersect1d(A, B)
    count = common.shape[0]
    return count, common
+
+
+
+
+
def RT_variance(X, Y):
    """Mean, over the unique rows of X, of the variance of their Y values.

    0 means duplicated sequences always carry the same retention time;
    larger values indicate inconsistent labels for identical sequences.
    """
    _, inverse, _ = np.unique(X, axis=0, return_inverse=True, return_counts=True)
    group_variances = [np.var(Y[inverse == g]) for g in range(inverse.max() + 1)]
    return np.mean(group_variances)
+
def RT_distrib(Y, f_name):
    """Save a 50-bin histogram of the retention-time values Y to `f_name`."""
    plt.hist(Y,bins = 50)
    plt.title("RT distribution")
    plt.savefig(f_name)
    plt.clf()
+
+# data = pd.read_csv('database/data_ptms.csv')
+# data_train = data[data.state == 'train']
+# data_train_2 = data_train.drop([data_train.columns[0] ,'sequence','irt_scaled','state'], axis = 1)
+# data_train_2.to_csv('data/RT/data_ptms_train.csv', index= False)
+# data_test = data[data.state == 'holdout']
+# data_validation = data[data.state == 'validation']
+
+
+# mean_unique_train = data_train.groupby(['mod_sequence'])['irt'].mean()
+# var_unique_train = data_train.groupby(['mod_sequence'])['irt'].var()
+# avg_train = pd.concat([mean_unique_train,var_unique_train], axis=1).reset_index()
+# avg_train['state'] = 'train'
+#
+# mean_unique_test = data_test.groupby(['mod_sequence'])['irt'].mean()
+# var_unique_test = data_test.groupby(['mod_sequence'])['irt'].var()
+# avg_test = pd.concat([mean_unique_test,var_unique_test], axis=1).reset_index()
+# avg_test['state'] = 'holdout'
+#
+# mean_unique_validation = data_validation.groupby(['mod_sequence'])['irt'].mean()
+# var_unique_validation = data_validation.groupby(['mod_sequence'])['irt'].var()
+# avg_validation = pd.concat([mean_unique_validation, var_unique_validation], axis=1).reset_index()
+# avg_validation['state'] = 'validation'
+#
+# avg = pd.concat([avg_train, avg_test, avg_validation])
+#
+# avg.columns.values[0] = 'mod_sequence'
+# avg.columns.values[1] = 'irt'
+# avg.columns.values[2] = 'var'
+# avg = avg.fillna(0)
+# data_unique = avg.query('var <= 1')
+#
+# avg.to_csv('database/data_unique.csv', index=False)
+#
+# data_unique = pd.read_csv('database/data_unique.csv')
+#
+# data_train = data_unique[data_unique.state == 'train']
+# data_test= data_unique[data_unique.state == 'holdout']
+# data_validation = data_unique[data_unique.state == 'validation']
+#
+#
+# np.save('data/RT/Y_unique_train.npy',data_train['irt'],allow_pickle=True)
+# np.save('data/RT/Y_unique_holdout.npy',data_test['irt'],allow_pickle=True)
+# np.save('data/RT/Y_unique_validation.npy',data_validation['irt'],allow_pickle=True)
+# np.save('data/RT/X_unique_train.npy',data_train['mod_sequence'],allow_pickle=True)
+# np.save('data/RT/X_unique_holdout.npy',data_test['mod_sequence'],allow_pickle=True)
+# np.save('data/RT/X_unique_validation.npy',data_validation['mod_sequence'],allow_pickle=True)
+# #
+#
+# X_train = np.load('data/RT/X_unique_train.npy',allow_pickle=True)
+# Y_train = np.load('data/RT/Y_unique_train.npy',allow_pickle=True)
+# X_validation = np.load('data/RT/X_unique_validation.npy',allow_pickle=True)
+# Y_validation = np.load('data/RT/Y_unique_validation.npy',allow_pickle=True)
+# X_test = np.load('data/RT/X_unique_holdout.npy',allow_pickle=True)
+# Y_test = np.load('data/RT/Y_unique_holdout.npy',allow_pickle=True)
+#
+# X_train = np.array(X_train.tolist())
+# X_validation = np.array(X_validation.tolist())
+# X_test = np.array(X_test.tolist())
+#
+#
+# print('\n Tailles des données')
+# print('\n Train : ', Y_train.size)
+# print('Validation : ', Y_validation.size)
+# print('Test: ', Y_test.size)
+#
+# print('\n Longueurs des séquences')
+# print('\n Train : ', dist_long(X_train, plot=False, save=True, f_name='fig/histo_length_train_unique.png'))
+# print('Validation : ', dist_long(X_validation, plot=False, save=True, f_name='fig/histo_length_validation_unique.png'))
+# print('Test : ', dist_long(X_test, plot=False, save=True, f_name='fig/histo_length_test_unique.png'))
+#
+# print('\n Fréquences des acides aminés')
+# print('\n Train : ', feq_aa(X_train, plot=False, save=True, f_name='fig/histo_aa_train_unique.png'))
+# print('Validation : ', feq_aa(X_validation, plot=False, save=True, f_name='fig/histo_aa_validation_unique.png'))
+# print('Test : ', feq_aa(X_test, plot=False, save=True, f_name='fig/histo_aa_test_unique.png'))
+#
+#
+# l1, c1 = intersection(X_train, X_validation)
+# l2, c2 = intersection(X_train, X_test)
+# l3, c3 = intersection(X_test, X_validation)
+#
+# print('\n Intersection checking')
+# print('\n Train x Validation : ', l1, ' intersection(s)')
+# print('Train x Test : ', l2, ' intersection(s)')
+# print('Test x validation : ', l3, ' intersection(s)')
+#
+# # print('\n RT Variance')
+# # print('\n Train : ', RT_variance(X_train, Y_train))
+# # print('\n Validation : ', RT_variance(X_validation, Y_validation))
+# # print('\n Test : ', RT_variance(X_test, Y_test))
+#
+# RT_distrib(Y_train,'fig/histo_RT_train_unique.png' )
+# RT_distrib(Y_test,'fig/histo_RT_test_unique.png' )
+# RT_distrib(Y_validation,'fig/histo_RT_validation_unique.png' )
+#
+# data_train = data[data.state == 'train']
+# data_test= data[data.state == 'holdout']
+# data_validation = data[data.state == 'validation']
+#
+#
+# np.save('data/RT/Y_train.npy',data_train['irt'],allow_pickle=True)
+# np.save('data/RT/Y_holdout.npy',data_test['irt'],allow_pickle=True)
+# np.save('data/RT/Y_validation.npy',data_validation['irt'],allow_pickle=True)
+# np.save('data/RT/X_train.npy',data_train['mod_sequence'],allow_pickle=True)
+# np.save('data/RT/X_holdout.npy',data_test['mod_sequence'],allow_pickle=True)
+# np.save('data/RT/X_validation.npy',data_validation['mod_sequence'],allow_pickle=True)
+# #
+#
+# X_train = np.load('data/RT/X_train.npy',allow_pickle=True)
+# Y_train = np.load('data/RT/Y_train.npy',allow_pickle=True)
+# X_validation = np.load('data/RT/X_validation.npy',allow_pickle=True)
+# Y_validation = np.load('data/RT/Y_validation.npy',allow_pickle=True)
+# X_test = np.load('data/RT/X_holdout.npy',allow_pickle=True)
+# Y_test = np.load('data/RT/Y_holdout.npy',allow_pickle=True)
+#
+# X_train = np.array(X_train.tolist())
+# X_validation = np.array(X_validation.tolist())
+# X_test = np.array(X_test.tolist())
+#
+#
+# print('\n Tailles des données')
+# print('\n Train : ', Y_train.size)
+# print('Validation : ', Y_validation.size)
+# print('Test: ', Y_test.size)
+#
+# print('\n Longueurs des séquences')
+# print('\n Train : ', dist_long(X_train, plot=False, save=True, f_name='fig/histo_length_train.png'))
+# print('Validation : ', dist_long(X_validation, plot=False, save=True, f_name='fig/histo_length_validation.png'))
+# print('Test : ', dist_long(X_test, plot=False, save=True, f_name='fig/histo_length_test.png'))
+#
+# print('\n Fréquences des acides aminés')
+# print('\n Train : ', feq_aa(X_train, plot=False, save=True, f_name='fig/histo_aa_train.png'))
+# print('Validation : ', feq_aa(X_validation, plot=False, save=True, f_name='fig/histo_aa_validation.png'))
+# print('Test : ', feq_aa(X_test, plot=False, save=True, f_name='fig/histo_aa_test.png'))
+#
+#
+#
+# l1, c1 = intersection(X_train, X_validation)
+# l2, c2 = intersection(X_train, X_test)
+# l3, c3 = intersection(X_test, X_validation)
+#
+# print('\n Intersection checking')
+# print('\n Train x Validation : ', l1, ' intersection(s)')
+# print('Train x Test : ', l2, ' intersection(s)')
+# print('Test x validation : ', l3, ' intersection(s)')
+#
+# print('\n RT Variance')
+# print('\n Train : ', RT_variance(X_train, Y_train))
+# print('\n Validation : ', RT_variance(X_validation, Y_validation))
+# print('\n Test : ', RT_variance(X_test, Y_test))
+#
+# RT_distrib(Y_train,'fig/histo_RT_train.png' )
+# RT_distrib(Y_test,'fig/histo_RT_test.png' )
+# RT_distrib(Y_validation,'fig/histo_RT_validation.png' )
+#
+#
diff --git a/data_viz.py b/data_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2764b8055e3d620d35087131eb3e7ffb42cede7
--- /dev/null
+++ b/data_viz.py
@@ -0,0 +1,113 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+from mass_prediction import compute_frag_mz_ration
+
seq = 'YEEEFLR'  # demo peptide used for the plots below


def data(a):
    """Demo helper: return `a` added to itself (string concat or numeric double)."""
    b=a+a
    return b

# NOTE(review): `int` shadows the builtin int() for the rest of this module —
# consider renaming to e.g. `intensities`.
int = np.random.rand(174)  # random placeholder intensity vector (174 fragment ions)
+
# Fragment-ion labels: for each position 1..29, b- and y-ions at charges
# 1+, 2+ and 3+ (29 x 2 x 3 = 174 entries, matching the intensity vectors).
names = ['b1(+)', 'y1(+)', 'b1(2+)', 'y1(2+)', 'b1(3+)', 'y1(3+)','b2(+)', 'y2(+)', 'b2(2+)', 'y2(2+)', 'b2(3+)', 'y2(3+)',
         'b3(+)', 'y3(+)', 'b3(2+)', 'y3(2+)', 'b3(3+)', 'y3(3+)', 'b4(+)', 'y4(+)', 'b4(2+)', 'y4(2+)', 'b4(3+)',
         'y4(3+)','b5(+)', 'y5(+)', 'b5(2+)', 'y5(2+)', 'b5(3+)', 'y5(3+)','b6(+)', 'y6(+)', 'b6(2+)', 'y6(2+)',
         'b6(3+)', 'y6(3+)','b7(+)', 'y7(+)', 'b7(2+)', 'y7(2+)', 'b7(3+)', 'y7(3+)','b8(+)', 'y8(+)', 'b8(2+)',
         'y8(2+)', 'b8(3+)', 'y8(3+)','b9(+)', 'y9(+)', 'b9(2+)', 'y9(2+)', 'b9(3+)', 'y9(3+)','b10(+)', 'y10(+)',
         'b10(2+)', 'y10(2+)', 'b10(3+)', 'y10(3+)','b11(+)', 'y11(+)', 'b11(2+)', 'y11(2+)', 'b11(3+)', 'y11(3+)',
         'b12(+)', 'y12(+)', 'b12(2+)', 'y12(2+)', 'b12(3+)', 'y12(3+)', 'b13(+)', 'y13(+)', 'b13(2+)', 'y13(2+)',
         'b13(3+)', 'y13(3+)','b14(+)', 'y14(+)', 'b14(2+)', 'y14(2+)', 'b14(3+)', 'y14(3+)','b15(+)', 'y15(+)',
         'b15(2+)', 'y15(2+)', 'b15(3+)', 'y15(3+)', 'b16(+)', 'y16(+)', 'b16(2+)', 'y16(2+)', 'b16(3+)', 'y16(3+)',
         'b17(+)', 'y17(+)', 'b17(2+)', 'y17(2+)', 'b17(3+)', 'y17(3+)','b18(+)', 'y18(+)', 'b18(2+)', 'y18(2+)',
         'b18(3+)', 'y18(3+)','b19(+)', 'y19(+)', 'b19(2+)', 'y19(2+)', 'b19(3+)', 'y19(3+)','b20(+)', 'y20(+)',
         'b20(2+)', 'y20(2+)', 'b20(3+)', 'y20(3+)','b21(+)', 'y21(+)', 'b21(2+)', 'y21(2+)', 'b21(3+)', 'y21(3+)',
         'b22(+)', 'y22(+)', 'b22(2+)', 'y22(2+)', 'b22(3+)', 'y22(3+)','b23(+)', 'y23(+)', 'b23(2+)', 'y23(2+)',
         'b23(3+)', 'y23(3+)','b24(+)', 'y24(+)', 'b24(2+)', 'y24(2+)', 'b24(3+)', 'y24(3+)','b25(+)', 'y25(+)',
         'b25(2+)', 'y25(2+)', 'b25(3+)', 'y25(3+)','b26(+)', 'y26(+)', 'b26(2+)', 'y26(2+)', 'b26(3+)', 'y26(3+)',
         'b27(+)', 'y27(+)', 'b27(2+)', 'y27(2+)', 'b27(3+)', 'y27(3+)','b28(+)', 'y28(+)', 'b28(2+)', 'y28(2+)',
         'b28(3+)', 'y28(3+)','b29(+)', 'y29(+)', 'b29(2+)', 'y29(2+)', 'b29(3+)', 'y29(3+)']

# array form so boolean masks can index it alongside the intensity vectors
names = np.array(names)
+
def frag_spectra(int, seq):
    """Plot the fragmentation spectrum of `seq` as a labelled stem plot.

    int: length-174 intensity vector; entries equal to -1. mark absent ions
         and are masked out. NOTE(review): the parameter shadows the builtin
         `int` — consider renaming.
    seq: peptide string, passed to compute_frag_mz_ration for the m/z axis.
    """
    masses = compute_frag_mz_ration(seq,'mono')
    msk = [el!=-1. for el in int]
    # keep only the ions that are actually present
    levels = int[msk]
    dates = masses[msk]
    # Create figure and plot a stem plot at the fragment m/z positions
    fig, ax = plt.subplots(figsize=(8.8, 4), constrained_layout=True)
    ax.set(title=seq + " fragmentation spectra")

    ax.vlines(dates, 0, levels, color="tab:red")  # The vertical stems.
    ax.plot(dates, np.zeros_like(dates),
            color="k", markerfacecolor="w")  # Baseline and markers on it.

    # annotate each stem with its ion name
    for d, l, r in zip(dates, levels, names):
        ax.annotate(r, xy=(d, l),
                    xytext=(-3, np.sign(l) * 3), textcoords="offset points",
                    horizontalalignment="right",
                    verticalalignment="bottom" if l > 0 else "top")


    plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

    # remove y axis and spines
    ax.yaxis.set_visible(False)
    ax.spines[["left", "top", "right"]].set_visible(False)

    ax.margins(y=0.1)
    plt.show()
+
def frag_spectra_comparison(int_1, seq_1, int_2, seq_2=None):
    """Mirror plot comparing two fragmentation spectra.

    Spectrum 1 is drawn upward in red, spectrum 2 downward in blue.
    int_1 / int_2: length-174 intensity vectors; -1 entries are masked out.
    If seq_2 is omitted both spectra are attributed to seq_1 (e.g. to
    compare a predicted and a measured spectrum of the same peptide).
    """
    if seq_2 is None :
        seq_2 = seq_1
    masses_1 = compute_frag_mz_ration(seq_1,'mono')
    msk_1 = [el!=-1 for el in int_1]
    levels_1 = int_1[msk_1]
    dates_1 = masses_1[msk_1]
    names_1 = names[msk_1]
    masses_2 = compute_frag_mz_ration(seq_2, 'mono')
    msk_2 = [el != -1. for el in int_2]
    levels_2 = int_2[msk_2]
    dates_2 = masses_2[msk_2]
    names_2 = names[msk_2]
    # Create figure and plot both spectra around a shared baseline
    fig, ax = plt.subplots(figsize=(8.8, 4), constrained_layout=True)
    ax.set(title=seq_1 + " / " +seq_2 + " fragmentation spectra comparison")

    ax.vlines(dates_1, 0, levels_1, color="tab:red")  # The vertical stems.
    ax.plot(dates_1, np.zeros_like(dates_1),
            color="k", markerfacecolor="w")  # Baseline and markers on it.

    # annotate spectrum 1 (upward)
    for d, l, r in zip(dates_1, levels_1, names_1):
        ax.annotate(r, xy=(d, l),
                    xytext=(-3, np.sign(l) * 3), textcoords="offset points",
                    horizontalalignment="right",
                    verticalalignment="bottom" if l > 0 else "top")

    # spectrum 2 is negated so it hangs below the baseline
    ax.vlines(dates_2, 0, -levels_2, color="tab:blue")  # The vertical stems.
    ax.plot(dates_2, np.zeros_like(dates_2),
            color="k", markerfacecolor="w")  # Baseline and markers on it.

    # annotate spectrum 2 (downward)
    for d, l, r in zip(dates_2, -levels_2, names_2):
        ax.annotate(r, xy=(d, l),
                    xytext=(-3, np.sign(l) * 3), textcoords="offset points",
                    horizontalalignment="right",
                    verticalalignment="bottom" if l > 0 else "top")




    plt.setp(ax.get_xticklabels(), rotation=30, ha="right")

    # remove y axis and spines
    ax.yaxis.set_visible(False)
    ax.spines[["left", "top", "right"]].set_visible(False)

    ax.margins(y=0.1)
    plt.show()
\ No newline at end of file
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b872e24d88f3ad13b482149bbe2d9ed40934fe
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,203 @@
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import pandas as pd
+
# Token table: 0 is the padding token "_", 1-20 the standard amino acids,
# 21/22 the two supported modifications (carbamidomethyl-Cys, oxidized Met).
ALPHABET_UNMOD = {
    "_": 0,
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
    "CaC": 21,
    "OxM": 22
}
+
+
def padding(dataframe, columns, length):
    """Pad the strings in `dataframe[columns]` with '_' up to `length`.

    Modification markers ('-CaC-', '-OxM-') take 5 characters but encode a
    single residue, so each pair of '-' grants 2 extra characters of slack.
    Rows whose sequence exceeds the allowed maximum are dropped.

    Bug fix: the original called `dataframe.drop(i)` without keeping the
    result or passing inplace=True, so no row was ever removed.

    The dataframe is mutated in place; nothing is returned.
    """
    def pad(x):
        return x + (length - len(x) + 2 * x.count('-')) * '_'

    # Drop over-long sequences *before* padding, then re-index so later
    # positional access (dataframe[columns][i]) stays dense.
    keep = dataframe[columns].map(lambda x: len(x) <= length + 2 * x.count('-'))
    dataframe.drop(dataframe.index[~keep], inplace=True)
    dataframe.reset_index(drop=True, inplace=True)
    dataframe[columns] = dataframe[columns].map(pad)
+
+
def alphabetical_to_numerical(seq):
    """Convert a padded sequence string to a list of ALPHABET_UNMOD indices.

    Modifications are written '-CaC-' / '-OxM-' (5 characters for a single
    residue) and map to 21 / 22.

    Raises ValueError on an unknown modification. (Bug fix: the original
    executed `raise 'Modification not supported'`, which raises a TypeError
    in Python 3 because strings are not exceptions.)
    """
    num = []
    dec = 0  # extra characters consumed by modification markers so far
    for i in range(len(seq) - 2 * seq.count('-')):
        if seq[i + dec] != '-':
            num.append(ALPHABET_UNMOD[seq[i + dec]])
        else:
            mod = seq[i + dec + 1:i + dec + 4]
            if mod == 'CaC':
                num.append(21)
            elif mod == 'OxM':
                num.append(22)
            else:
                raise ValueError('Modification not supported: ' + mod)
            dec += 4  # skip the rest of the '-XXX-' marker
    return num
+
+
class RT_Dataset(Dataset):
    """Retention-time dataset backed by a CSV file.

    Selects the rows for `mode` ('train' / 'test' / 'validation'; the CSV
    uses 'holdout' for the test split), optionally subsamples `size` rows,
    pads sequences to `length` with '_', numerically encodes them, and
    yields (sequence_tensor, float_label_tensor) pairs. `format` selects
    the label column ('RT', 'iRT', 'iRT_scaled' or 'score').
    """

    def __init__(self, size, data_source, mode, length, format='iRT'):
        print('Data loader Initialisation')
        self.data = pd.read_csv(data_source)

        self.mode = mode
        self.format = format

        print('Selecting data')
        if mode == 'train':
            self.data = self.data[self.data.state == 'train']
        elif mode == 'test':
            self.data = self.data[self.data.state == 'holdout']
        elif mode == 'validation':
            self.data = self.data[self.data.state == 'validation']
        if size is not None:
            self.data = self.data.sample(size)  # random subsample

        print('Padding')
        self.data['sequence'] = self.data['sequence'].str.pad(length, side='right', fillchar='_')
        # drop rows still longer than `length` (e.g. modified sequences)
        self.data = self.data.drop(self.data[self.data['sequence'].map(len) > length].index)

        print('Converting')
        self.data['sequence'] = self.data['sequence'].map(alphabetical_to_numerical)

        self.data = self.data.reset_index()

    def __getitem__(self, index: int):
        seq = self.data['sequence'][index]
        if self.format == 'RT':
            label = self.data['retention_time'][index]
        elif self.format == 'iRT':
            label = self.data['irt'][index]
        elif self.format == 'iRT_scaled':
            label = self.data['iRT_scaled'][index]
        elif self.format == 'score':
            label = self.data['score'][index]
        else:
            # Bug fix: the original silently left `label` unbound for an
            # unknown format, causing a confusing NameError at access time.
            raise ValueError('Unknown label format: ' + str(self.format))
        return torch.tensor(seq), torch.tensor(label).float()

    def __len__(self) -> int:
        return self.data.shape[0]
+
+
def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30):
    """Build shuffled train/val/test DataLoaders from three CSV paths.

    data_sources: (train_csv, test_csv, validation_csv) paths;
    n_train / n_test: optional subsample sizes passed to RT_Dataset.
    """
    print('Loading data')
    sizes = (n_train, n_test, n_test)
    modes = ('train', 'test', 'validation')
    train_ds, test_ds, val_ds = (
        RT_Dataset(size, source, mode, length)
        for size, source, mode in zip(sizes, data_sources, modes)
    )

    def make_loader(ds):
        return DataLoader(ds, batch_size=batch_size, shuffle=True)

    return make_loader(train_ds), make_loader(val_ds), make_loader(test_ds)
+
+
class H5ToStorage():
    """Utility to dump datasets from an HDF5 file into .npy files."""

    def __init__(self, hdf_path):
        self.path = hdf_path
        # record the top-level dataset/group names once at construction
        with h5py.File(hdf_path, 'r') as hf:
            self.classes = list(hf)

    def get_class(self):
        """Return the top-level names found in the HDF5 file."""
        return self.classes

    def make_npy_file(self, f_name, column):
        """Save the HDF5 dataset `column` to `f_name` as a .npy file."""
        with h5py.File(self.path, 'r') as hf:
            np.save(f_name, hf[column])
+
+
def load_split_intensity(sources, batch_size, split=(0.5, 0.25, 0.25)):
    """Load the intensity-task arrays and split them into three DataLoaders.

    sources: .npy paths in the order (sequence, intensity, collision-energy,
    precursor-charge). `split` gives the train/validation/test fractions and
    must sum to 1. Returns (train_loader, validation_loader, test_loader),
    all unshuffled.

    Fix: the local variable `len` no longer shadows the builtin.
    """
    assert sum(split) == 1, 'Wrong split argument'
    seq = np.load(sources[0])
    intensity = np.load(sources[1])
    energy = np.load(sources[2])
    precursor_charge = np.load(sources[3])

    n_samples = np.shape(energy)[0]
    ind1 = int(np.floor(n_samples * split[0]))
    ind2 = int(np.floor(n_samples * (split[0] + split[1])))
    train = (seq[:ind1], intensity[:ind1], energy[:ind1], precursor_charge[:ind1])
    validation = (
        seq[ind1:ind2], intensity[ind1:ind2], energy[ind1:ind2], precursor_charge[ind1:ind2])
    test = (seq[ind2:], intensity[ind2:], energy[ind2:], precursor_charge[ind2:])

    train_loader = DataLoader(Intentsity_Dataset(train), batch_size=batch_size)
    test_loader = DataLoader(Intentsity_Dataset(test), batch_size=batch_size)
    validation_loader = DataLoader(Intentsity_Dataset(validation), batch_size=batch_size)

    return train_loader, validation_loader, test_loader
+
+
def load_intensity_from_files(f_seq, f_intentsity, f_energy, f_percursor_charge, batch_size):
    """Load the four intensity-task arrays from .npy files and wrap them in a DataLoader."""
    arrays = (np.load(f_seq), np.load(f_intentsity),
              np.load(f_energy), np.load(f_percursor_charge))
    return DataLoader(Intentsity_Dataset(arrays), batch_size=batch_size)
+
def load_intensity_df_from_files(f_seq, f_intentsity, f_energy, f_percursor_charge):
    """Same as load_intensity_from_files but return the Dataset itself, without a DataLoader."""
    arrays = (np.load(f_seq), np.load(f_intentsity),
              np.load(f_energy), np.load(f_percursor_charge))
    return Intentsity_Dataset(arrays)
+
class Intentsity_Dataset(Dataset):
    """Dataset over pre-loaded (sequence, intensity, energy, precursor_charge) arrays.

    __getitem__ yields (seq, [energy], charge, intensity) tensors; the energy
    is wrapped in a length-1 float tensor and the intensity cast to float.
    """

    def __init__(self, data):
        self.data = data
        self.seq, self.intensity, self.energy, self.precursor_charge = data

    def __len__(self):
        return len(self.seq)

    def __getitem__(self, idx):
        seq_t = torch.tensor(self.seq[idx])
        energy_t = torch.tensor([self.energy[idx]]).float()
        charge_t = torch.tensor(self.precursor_charge[idx])
        intensity_t = torch.tensor(self.intensity[idx]).float()
        return seq_t, energy_t, charge_t, intensity_t
+
+# storage = H5ToStorage('database/holdout_hcd.hdf5')
+# storage.make_npy_file('data/intensity/method.npy','method')
+# storage.make_npy_file('data/intensity/sequence_header.npy','sequence_integer')
+# storage.make_npy_file('data/intensity/intensity_header.npy', 'intensities_raw')
+# storage.make_npy_file('data/intensity/collision_energy_header.npy', 'collision_energy_aligned_normed')
+# storage.make_npy_file('data/intensity/precursor_charge_header.npy', 'precursor_charge_onehot')
diff --git a/decoy.py b/decoy.py
new file mode 100644
index 0000000000000000000000000000000000000000..53110e1e1cd1192e3cbacdf1b50fefa5f3fbf79a
--- /dev/null
+++ b/decoy.py
@@ -0,0 +1,90 @@
+import torch.distributions as dist
+import random
+import numpy as np
+
# Integer encoding of the 20 standard (unmodified) amino acids used by the
# decoy helpers below. K=9 and R=15 are the trypsin cleavage residues that
# reverse_peptide/shuffle_peptide keep anchored in place.
# NOTE(review): code 0 is unused — presumably reserved for padding; confirm.
ALPHABET_UNMOD = {
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
}
+
def reverse_protein(seq):
    """Return a reversed copy of *seq* (classic full-protein decoy).

    Works for both strings and lists; the input is not modified.
    """
    reversed_seq = seq[::-1]
    return reversed_seq
+
def reverse_peptide(seq, format='numerical'):
    """Reverse *seq* in place, skipping positions that hold a trypsin
    cleavage residue (K/R) in the first half.

    Parameters
    ----------
    seq : list
        Mutable peptide sequence, either integer-encoded
        (format='numerical': K=9, R=15 per ALPHABET_UNMOD) or as
        single-letter codes (format='alphabetical').
    format : str
        'numerical' or 'alphabetical'. Any other value leaves `seq`
        untouched (matching the original behavior).

    Returns
    -------
    list
        The same list object, modified in place.
    """
    if format == 'numerical':
        anchors = (9, 15)
    elif format == 'alphabetical':
        anchors = ('K', 'R')
    else:
        return seq
    for i in range(len(seq) // 2):
        # bug fix: the mirror of index i is -(i + 1); the original used -i,
        # which for i == 0 swapped the first element with itself (since
        # seq[-0] is seq[0]) and mis-paired every other position by one.
        j = -(i + 1)
        if seq[i] not in anchors:
            seq[i], seq[j] = seq[j], seq[i]
    return seq
+
def shuffle_protein(seq):
    """Return a randomly shuffled copy of *seq*; the input list is untouched."""
    shuffled = seq.copy()
    random.shuffle(shuffled)
    return shuffled
+
def shuffle_peptide(seq, format='numerical'):
    """Shuffle the non-cleavage residues of a peptide, keeping every K/R
    anchored at its original position (tryptic-decoy shuffle).

    This replaces the original implementation (self-marked "TODO A reparer"),
    which crashed: `np.where(seq == 9 or seq == 15, seq)` is invalid (two-arg
    np.where plus `or` on arrays), and the delete/re-insert bookkeeping used
    the wrong indices.

    Parameters
    ----------
    seq : sequence
        Peptide, integer-encoded (format='numerical': K=9, R=15) or as
        single-letter codes (format='alphabetical').
    format : str
        'numerical' or 'alphabetical'.

    Returns
    -------
    list
        A new list: a permutation of `seq` where only non-K/R residues moved.

    Raises
    ------
    ValueError
        If `format` is not one of the two supported values.
    """
    if format == 'numerical':
        anchors = (9, 15)
    elif format == 'alphabetical':
        anchors = ('K', 'R')
    else:
        raise ValueError("format must be 'numerical' or 'alphabetical'")
    result = list(seq)
    # Positions that are allowed to move, and the residues at those positions.
    movable_pos = [i for i, aa in enumerate(result) if aa not in anchors]
    movable = [result[i] for i in movable_pos]
    random.shuffle(movable)
    # Write the shuffled residues back into the movable slots only.
    for pos, aa in zip(movable_pos, movable):
        result[pos] = aa
    return result
def random_aa(database, format='numerical'):
    """Sample a random residue sequence from the empirical amino-acid
    distribution of `database` (random-sequence decoy generation).

    NOTE(review): relies on a project-specific API (`database.unroll()`,
    `.count()`, `.normalize()`) not visible in this file — confirm semantics.
    `format` is currently unused. `d.sample(l)` with a plain int also looks
    suspect for torch.distributions.Categorical (expects a sample_shape
    tuple) — verify before use.
    """
    total_seq = database.unroll()
    freq = total_seq.count()
    freq.normalize()
    d = dist.Categorical(freq)
    l = len(total_seq)
    new_seq = d.sample(l)
    # TODO: cut `new_seq` into peptides with a similar length distribution
    # ("similarcutting" in the original note).
    return new_seq  # bug fix: the sampled sequence was computed but never returned
+
+
def random_aa_trypsin(database, format='numerical'):
    """Sample a random residue sequence from the amino-acid distribution of
    `database`, excluding the trypsin cleavage residues K/R first.

    NOTE(review): same opaque project API caveats as random_aa. If
    `total_seq` behaves like a plain list, `.remove()` deletes only the
    FIRST occurrence of each residue — removing all K/R occurrences was
    probably intended; confirm.
    """
    total_seq = database.unroll()
    if format == 'numerical':
        total_seq.remove(9)
        total_seq.remove(15)
    if format == 'alphabetical':
        total_seq.remove('R')
        total_seq.remove('K')
    freq = total_seq.count()
    freq.normalize()
    d = dist.Categorical(freq)
    l = len(total_seq)
    new_seq = d.sample(l)
    # TODO: similar-length cutting of the sampled sequence.
    return new_seq  # bug fix: the sampled sequence was computed but never returned
diff --git a/key_peptide_distribution.py b/key_peptide_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..048cf1a176473beaf9804d337d9a25706d5a7467
--- /dev/null
+++ b/key_peptide_distribution.py
@@ -0,0 +1,29 @@
+import pandas as pd
+from data_exploration import dist_long, feq_aa
+
+
+# Enterobac = pd.read_excel('database/Enterobac.xlsx')
+# Enterococcus = pd.read_excel('database/Enterococcus.xlsx')
+# Pyo = pd.read_excel('database/Pyo.xlsx')
+# Staph = pd.read_excel('database/Staph.xlsx')
+#
+# Enterobac['id']='Enterobac'
+# Enterococcus['id']='Enterococcus'
+# Pyo['id']='Pyo'
+# Staph['id']='Staph'
+#
+# res = pd.concat([Enterobac, Enterococcus, Pyo, Staph], ignore_index=True)
+#
+# id = pd.read_excel('database/All_Peptides.xlsx')
+#
+# res.to_csv('database/peptides_res.csv')
+# id.to_csv('database/peptides_id.csv')
+
# Load the concatenated per-species peptide table and the identification
# table produced by the (commented-out) Excel conversion above, then plot
# peptide-length distributions and amino-acid frequencies for both.
res = pd.read_csv('database/peptides_res.csv')
# bug-style fix: renamed from `id` to avoid shadowing the Python builtin.
peptide_ids = pd.read_csv('database/peptides_id.csv')

dist_res = dist_long(res['Peptides'], plot=False, save=True, f_name='fig/res_dist_long.png')
dist_id = dist_long(peptide_ids['Peptides'], plot=False, save=True, f_name='fig/id_dist_long.png')

freq_res = feq_aa(res['Peptides'], plot=False, save=True, f_name='fig/res_feq_aa.png')
freq_id = feq_aa(peptide_ids['Peptides'], plot=False, save=True, f_name='fig/id_feq_aa.png')
\ No newline at end of file
diff --git a/layers.py b/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1ca54adb401e9084686d4498b353eaffaaa2551
--- /dev/null
+++ b/layers.py
@@ -0,0 +1,193 @@
+import math
+
+import torch
+from torch import nn
+
+
class SelectItem(nn.Module):
    """Select one element from an indexable input.

    Useful inside nn.Sequential to pick a single tensor out of a tuple
    returned by a previous layer (e.g. an nn.LSTM's (output, state) pair).
    """

    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        # Index to extract in forward().
        self.item_index = item_index

    def forward(self, inputs):
        # Plain indexing into whatever container the previous layer produced.
        selected = inputs[self.item_index]
        return selected
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, input_dim):
+        super(SelfAttention, self).__init__()
+        self.input_dim = input_dim
+        self.query = nn.Linear(input_dim, input_dim)
+        self.key = nn.Linear(input_dim, input_dim)
+        self.value = nn.Linear(input_dim, input_dim)
+        self.softmax = nn.Softmax(dim=2)
+
+    def forward(self, x):
+        queries = self.query(x)
+        keys = self.key(x)
+        values = self.value(x)
+        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
+        attention = self.softmax(scores)
+        weighted = torch.bmm(attention, values)
+        out = torch.sum(weighted, axis=1)
+        return out
+
+
class SelfAttention_multi(nn.Module):
    """Multi-head self-attention pooled over the sequence axis.

    The feature dimension is split into `n_head` contiguous slices of size
    input_dim // n_head; each head projects its own slice, the heads are
    concatenated back, and the attended values are summed over the sequence
    axis: (batch, seq, input_dim) -> (batch, input_dim).
    """

    def __init__(self, input_dim, n_head=1):
        """
        input_dim - total feature dimensionality (must be divisible by n_head)
        n_head    - number of attention heads

        Raises ValueError if input_dim is not divisible by n_head.
        """
        if input_dim % n_head != 0:
            # bug fix: the original `raise "Incompatible n_head"` raises a
            # string, which is itself a TypeError in Python 3.
            raise ValueError("input_dim must be divisible by n_head")
        super(SelfAttention_multi, self).__init__()
        self.input_dim = input_dim // n_head  # per-head feature size
        self.n_head = n_head
        # (removed dead `self.query = []`-style placeholders immediately
        # overwritten by the ModuleLists below)
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        q, k, v = [], [], []
        for i in range(self.n_head):
            # Per-head slice of the feature dimension.
            head_slice = x[:, :, self.input_dim * i:self.input_dim * (i + 1)]
            q.append(self.query[i](head_slice))
            k.append(self.key[i](head_slice))
            v.append(self.value[i](head_slice))

        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        # NOTE: scores are scaled by sqrt of the per-head dim, and attention
        # is computed over the full concatenated dim (heads interact in the
        # bmm) — kept exactly as in the original implementation.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        return torch.sum(weighted, axis=1)
+
+
class SelfAttention_multi_no_sum(nn.Module):
    """Multi-head self-attention that keeps the sequence axis.

    Same head-sliced projection scheme as SelfAttention_multi, but the
    attended values are returned per position:
    (batch, seq, input_dim) -> (batch, seq, input_dim).
    """

    def __init__(self, input_dim, n_head=1):
        """
        input_dim - total feature dimensionality (must be divisible by n_head)
        n_head    - number of attention heads

        Raises ValueError if input_dim is not divisible by n_head.
        """
        if input_dim % n_head != 0:
            # bug fix: the original `raise "Incompatible n_head"` raises a
            # string, which is itself a TypeError in Python 3.
            raise ValueError("input_dim must be divisible by n_head")
        super(SelfAttention_multi_no_sum, self).__init__()
        self.input_dim = input_dim // n_head  # per-head feature size
        self.n_head = n_head
        # (removed dead `self.query = []`-style placeholders immediately
        # overwritten by the ModuleLists below)
        self.query = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.key = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.value = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(self.n_head)])
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        q, k, v = [], [], []
        for i in range(self.n_head):
            # Per-head slice of the feature dimension.
            head_slice = x[:, :, self.input_dim * i:self.input_dim * (i + 1)]
            q.append(self.query[i](head_slice))
            k.append(self.key[i](head_slice))
            v.append(self.value[i](head_slice))

        queries = torch.cat(q, dim=2)
        keys = torch.cat(k, dim=2)
        values = torch.cat(v, dim=2)
        # NOTE: scaled by sqrt of the per-head dim; attention computed over
        # the full concatenated dim, as in the original implementation.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(scores)
        # No pooling: one attended vector per sequence position.
        return torch.bmm(attention, values)
+
+
class EncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block with residual connections.

    The attention sub-layer is SelfAttention_multi_no_sum (keeps the
    sequence axis); it is followed by a GELU MLP. Dropout is applied to
    both residual branches.
    """

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0, *args, **kwargs):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__(*args, **kwargs)

        # Attention sub-layer.
        self.self_attn = SelfAttention_multi_no_sum(input_dim, num_heads)

        # Position-wise two-layer MLP.
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim_feedforward, input_dim),
        )

        # Pre-norm layers and the shared residual dropout.
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention branch: pre-norm, attend, dropout, residual add.
        x = x + self.dropout(self.self_attn(self.norm1(x)))
        # Feed-forward branch: pre-norm, MLP, dropout, residual add.
        x = x + self.dropout(self.linear_net(self.norm2(x)))
        return x
+
+
class TransformerEncoder(nn.Module):
    """Stack of `num_layers` EncoderBlock modules applied sequentially."""

    def __init__(self, num_layers, **block_args):
        super().__init__()
        # One independently-parameterised EncoderBlock per layer; block_args
        # are forwarded unchanged to each EncoderBlock constructor.
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x):
        # Feed the input through each block in order.
        for l in self.layers:
            x = l(x)
        return x

    def get_attention_maps(self, x, mask=None):
        # NOTE(review): EncoderBlock.self_attn (SelfAttention_multi_no_sum)
        # accepts neither `mask` nor `return_attention`, so this call would
        # raise TypeError as the file stands — confirm which attention
        # module this method was written against before using it.
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps
+
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017) added to the input.

    Inputs
        d_model - Hidden dimensionality of the input.
        max_len - Maximum length of a sequence to expect.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Precompute the (max_len, d_model) table: sin on even feature
        # indices, cos on odd ones, with geometrically spaced frequencies.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(positions * freqs)
        encoding[:, 1::2] = torch.cos(positions * freqs)

        # Registered as a buffer (not a parameter) so it moves with the
        # module across devices; persistent=False keeps it out of the
        # state dict when saving.
        self.register_buffer('pe', encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        # Add the encoding for the first x.size(1) positions.
        return x + self.pe[:, :x.size(1)]
diff --git a/local_integration_msms.py b/local_integration_msms.py
new file mode 100644
index 0000000000000000000000000000000000000000..0015141dd22bf5c8f62930ca9a2aea60aa34307b
--- /dev/null
+++ b/local_integration_msms.py
@@ -0,0 +1,27 @@
+import pyopenms as oms
+import numpy as np
+
def compute_chromatograms(rt, charge, intensity, start_c, end_c):
    """Integrate per-scan intensity inside an open window (start_c, end_c).

    Parameters
    ----------
    rt : sequence
        Retention times, one per scan (used only for its length here).
    charge : sequence of array-likes
        Per-scan axis values the window is applied to (from the caller these
        are the m/z arrays of each spectrum, despite the parameter name).
    intensity : sequence of array-likes
        Per-scan peak intensities, parallel to `charge`.
    start_c, end_c : float
        Exclusive lower/upper bounds of the integration window.

    Returns
    -------
    list
        One summed intensity per scan (an extracted-ion-chromatogram trace).
    """
    value = []
    for k in range(len(rt)):
        c = np.array(charge[k])
        i = np.array(intensity[k])
        # bug fix: the original chained comparison `end_c > c > start_c`
        # raises ValueError on numpy arrays (ambiguous truth value); combine
        # the two comparisons element-wise instead.
        in_window = (c > start_c) & (c < end_c)
        value.append(np.sum(np.where(in_window, i, 0)))
    return value
+
+
if __name__ == "__main__":
    # Load a raw LC-MS run; updateRanges() refreshes the experiment's
    # cached min/max RT and m/z bounds.
    e = oms.MSExperiment()
    oms.MzMLFile().load("data/STAPH140.mzML", e)
    e.updateRanges()
    rt = []
    # NOTE(review): get_peaks()[0] is the m/z array of the spectrum, not a
    # charge — the variable name is misleading; confirm and consider renaming.
    charge = []
    intensity = []
    for s in e :
        if s.getMSLevel() == 1:  # keep MS1 survey scans only
            rt.append(s.getRT())
            charge.append(s.get_peaks()[0])
            intensity.append(s.get_peaks()[1])
    # Integrate intensity in the m/z window (400, 400.5) per scan — an
    # extracted-ion-chromatogram trace over retention time.
    val = compute_chromatograms(rt, charge, intensity, 400. ,400.5)
diff --git a/logs/train/events.out.tfevents.1711626488.r9i4n1.955314.1.v2 b/logs/train/events.out.tfevents.1711626488.r9i4n1.955314.1.v2
new file mode 100644
index 0000000000000000000000000000000000000000..584e505bece0331a7e4639c6ad29a48dfe43007c
Binary files /dev/null and b/logs/train/events.out.tfevents.1711626488.r9i4n1.955314.1.v2 differ
diff --git a/logs/validation/events.out.tfevents.1711626506.r9i4n1.955314.2.v2 b/logs/validation/events.out.tfevents.1711626506.r9i4n1.955314.2.v2
new file mode 100644
index 0000000000000000000000000000000000000000..f4523bad4d517f8b0fa8c9e46ffb66af92098b09
Binary files /dev/null and b/logs/validation/events.out.tfevents.1711626506.r9i4n1.955314.2.v2 differ
diff --git a/logs_lr/train/events.out.tfevents.1711637459.r3i4n0.2741367.1.v2 b/logs_lr/train/events.out.tfevents.1711637459.r3i4n0.2741367.1.v2
new file mode 100644
index 0000000000000000000000000000000000000000..589c43dcbee11e17fca447953a7dc3a745cec51c
Binary files /dev/null and b/logs_lr/train/events.out.tfevents.1711637459.r3i4n0.2741367.1.v2 differ
diff --git a/logs_lr/train/events.out.tfevents.1711640539.r3i5n0.851015.1.v2 b/logs_lr/train/events.out.tfevents.1711640539.r3i5n0.851015.1.v2
new file mode 100644
index 0000000000000000000000000000000000000000..e7276e7bcf41182c5060bec1c952d03d3e006dc8
Binary files /dev/null and b/logs_lr/train/events.out.tfevents.1711640539.r3i5n0.851015.1.v2 differ
diff --git a/logs_lr/train/events.out.tfevents.1711642953.r10i3n0.4169735.1.v2 b/logs_lr/train/events.out.tfevents.1711642953.r10i3n0.4169735.1.v2
new file mode 100644
index 0000000000000000000000000000000000000000..8d8f9a1f92437ff8b806c1a20ecb100279649d99
Binary files /dev/null and b/logs_lr/train/events.out.tfevents.1711642953.r10i3n0.4169735.1.v2 differ
diff --git a/logs_lr/validation/events.out.tfevents.1711637474.r3i4n0.2741367.2.v2 b/logs_lr/validation/events.out.tfevents.1711637474.r3i4n0.2741367.2.v2
new file mode 100644
index 0000000000000000000000000000000000000000..43a03d61be6acf0f925618e0ec276fa9cdab5856
Binary files /dev/null and b/logs_lr/validation/events.out.tfevents.1711637474.r3i4n0.2741367.2.v2 differ
diff --git a/logs_lr/validation/events.out.tfevents.1711640556.r3i5n0.851015.2.v2 b/logs_lr/validation/events.out.tfevents.1711640556.r3i5n0.851015.2.v2
new file mode 100644
index 0000000000000000000000000000000000000000..f6040043d4e05a2312f093ad22627023008da154
Binary files /dev/null and b/logs_lr/validation/events.out.tfevents.1711640556.r3i5n0.851015.2.v2 differ
diff --git a/logs_lr/validation/events.out.tfevents.1711642969.r10i3n0.4169735.2.v2 b/logs_lr/validation/events.out.tfevents.1711642969.r10i3n0.4169735.2.v2
new file mode 100644
index 0000000000000000000000000000000000000000..b1f595aad3d27cdea567e88dd2adc6da13ea6fb7
Binary files /dev/null and b/logs_lr/validation/events.out.tfevents.1711642969.r10i3n0.4169735.2.v2 differ
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0021caf94bb9fb8eb90572c0c6f9ad012937c2d5
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn.functional as F
+from torchmetrics.regression import PearsonCorrCoef
+import numpy as np
+
+
def masked_cos_sim(y_true, y_pred):
    """Negative masked cosine similarity between target and predicted
    intensity vectors (a loss: -1 when directions match exactly).

    Positions where y_true == -1 mark peaks that cannot exist; the
    (y_true + 1) factor zeroes those positions in both vectors before
    L2-normalisation, so they do not contribute to the similarity.
    """
    # Fuzzing constant for numerical stability on GPUs.
    eps = 1e-7
    mask = y_true + 1
    masked_pred = F.normalize(mask * y_pred / (mask + eps), p=2, dim=1)
    masked_true = F.normalize(mask * y_true / (mask + eps), p=2, dim=1)
    cosine = (masked_pred * masked_true).sum(dim=1)
    return -cosine.mean()
+
+
def masked_spectral_angle(y_true, y_pred):
    """Masked spectral-angle similarity between true and predicted vectors.

    Returns 1 - 2*theta/pi where theta is the mean angle between the masked,
    L2-normalised vectors: 1 for identical spectra, 0 for orthogonal ones.
    Positions where y_true == -1 (impossible peaks) are zeroed out by the
    (y_true + 1) factor before normalisation.
    """
    # Fuzzing constant for numerical stability on GPUs.
    eps = 1e-7
    mask = y_true + 1
    pred_n = F.normalize(mask * y_pred / (mask + eps), p=2, dim=1)
    true_n = F.normalize(mask * y_true / (mask + eps), p=2, dim=1)
    theta = torch.acos((pred_n * true_n).sum(dim=1))
    return 1 - 2 * theta.mean() / np.pi
+
+
def masked_pearson_correlation_distance(y_true, y_pred, reduce='mean'):
    """
    Calculates the masked Pearson correlation between true and predicted
    intensity vectors.

    Only non-negative entries in `y_true` (valid peaks; impossible peaks are
    encoded as -1) contribute, thanks to the (y_true + 1) masking factor.

    Parameters:
    -----------
    y_true : Tensor
        A tensor containing the true values, with shape `(batch_size, num_values)`.
    y_pred : Tensor
        A tensor containing the predicted values, with the same shape as `y_true`.
    reduce : str or None
        'mean', 'sum', or None (no reduction).

    Returns:
    --------
    Tensor
        The (reduced) Pearson correlation between the masked tensors.
        NOTE(review): despite the name, this returns the correlation itself,
        not a distance (1 - r) — confirm which the callers expect.
        Also note: any `reduce` value other than 'mean'/'sum'/None falls
        through and returns None.
    """
    epsilon = 1e-7

    # Masking: we multiply values by (true + 1) because then the peaks that cannot
    # be there (and have value of -1 as explained above) won't be considered
    pred_masked = ((y_true + 1) * y_pred) / (y_true + 1 + epsilon)
    true_masked = ((y_true + 1) * y_true) / (y_true + 1 + epsilon)
    # NOTE(review): torchmetrics' PearsonCorrCoef() defaults to num_outputs=1
    # and expects 1-D inputs; calling it on (batch, num_values) tensors may
    # error or misbehave — verify against the torchmetrics version in use.
    loss = PearsonCorrCoef()
    if reduce == 'mean':
        return torch.mean(loss(pred_masked, true_masked))
    if reduce == 'sum':
        return torch.sum(loss(pred_masked, true_masked))
    if reduce is None:
        return loss(pred_masked, true_masked)
+
def distance(x, y):
    """Mean absolute difference between two tensors (L1 metric)."""
    return (x - y).abs().mean()
+
+
def cos_sim_to_sa(cos):
    """Convert a cosine similarity to the normalised spectral angle
    1 - 2*arccos(cos)/pi (1 for identical, 0 for orthogonal spectra)."""
    angle = np.arccos(cos)
    return 1 - 2 * angle / np.pi
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff9cd6e14c75e1d8f1ae76bed652686956a71f9
--- /dev/null
+++ b/main.py
@@ -0,0 +1,372 @@
+import os
+
+import torch
+import torch.optim as optim
+# import wandb as wdb
+import numpy as np
+
+from config import load_args
+from dataloader import load_data, load_intensity_from_files
+from loss import masked_cos_sim, distance, masked_spectral_angle
+from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head, RT_pred_model_self_attention_pretext,
+                   RT_pred_model_self_attention_multi, RT_pred_model_self_attention_multi_sum,
+                   RT_pred_model_transformer)
+
+
+# from torcheval.metrics import R2Score
+
+
+# def compute_metrics(model, data_val, f_name):
+#     name = os.path.join('checkpoints', f_name)
+#     model.load_state_dict(torch.load(name))
+#     model.eval()
+#     targets = []
+#     preds = []
+#     r2 = R2Score()
+#     for data, target in data_val:
+#         targets.append(target)
+#         pred = model(data)
+#         preds.append(pred)
+#     full_target = torch.concat(targets, dim=0)
+#     full_pred = torch.concat(preds, dim=0)
+#
+#     r2.update(full_pred, full_target)
+#     diff = torch.abs(full_target - full_pred)
+#     sorted_diff, _ = diff.sort()
+#     delta_95 = sorted_diff[int(np.floor(sorted_diff.size(dim=0) * 0.95))].item()
+#     score = r2.compute()
+#     return score, delta_95
+
+
def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
    """Run one training epoch of the retention-time model.

    Parameters
    ----------
    model : nn.Module mapping a batch of sequences to RT predictions.
    data_train : iterable of (data, target) batches.
    epoch : int, current epoch number (for logging only).
    optimizer : torch optimizer over model.parameters().
    criterion : loss callable (pred, target) -> scalar tensor.
    metric : monitoring callable (pred, target) -> scalar tensor.
    wandb : run name or None; when not None the epoch averages are logged.
        NOTE(review): `wdb` (wandb) is only imported in a commented-out line
        at module top, so a non-None `wandb` raises NameError as the file
        stands — confirm before enabling logging.
    """
    losses = 0.
    dist_acc = 0.
    model.train()
    # Re-enable gradients (the eval_* functions disable them).
    for param in model.parameters():
        param.requires_grad = True
    for data, target in data_train:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt = model.forward(data)
        # bug fix: Tensor.float() is not in-place; the original discarded
        # the result, leaving `target` in its original dtype.
        target = target.float()
        loss = criterion(pred_rt, target)
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train mean metric": dist_acc / len(data_train),
                 'train epoch': epoch})

    print('epoch : ', epoch, ',train losses : ', losses / len(data_train), " ,mean metric : ",
          dist_acc / len(data_train))
+
+
def train_pretext(model, data_train, epoch, optimizer, criterion, task, metric, coef, wandb=None):
    """Run one training epoch with an auxiliary sequence-reconstruction task.

    The model must return (pred_rt, pred_seq); pred_seq is transposed to
    (batch, classes, seq_len) for CrossEntropyLoss against the integer
    input sequence. The total loss is criterion + coef * task.

    NOTE(review): `wdb` (wandb) is only imported in a commented-out line at
    module top, so a non-None `wandb` raises NameError as the file stands.
    """
    losses, losses_2 = 0., 0.
    dist_acc = 0.
    model.train()
    # Re-enable gradients (the eval_* functions disable them).
    for param in model.parameters():
        param.requires_grad = True
    for data, target in data_train:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        pred_rt, pred_seq = model.forward(data)
        # CrossEntropyLoss wants (batch, classes, seq_len).
        pred_seq = pred_seq.transpose(1, 2)
        # bug fix: Tensor.float() is not in-place; the original discarded
        # the result, leaving `target` in its original dtype.
        target = target.float()
        loss = criterion(pred_rt, target)
        loss_2 = task(pred_seq, data)
        losses_2 += loss_2.item()
        loss_tot = loss + coef * loss_2
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train loss pretext": losses_2 / len(data_train),
                 "train mean metric": dist_acc / len(data_train), 'train epoch': epoch})

    print('epoch : ', epoch, ',train losses : ', losses / len(data_train), ',train pretext losses : ',
          losses_2 / len(data_train), " ,mean metric : ",
          dist_acc / len(data_train))
+
+
def train_int(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
    """Run one training epoch of the intensity model.

    Batches are (sequence, collision energy, precursor charge, intensity);
    the first three are forwarded to the model, the fourth is the target.

    NOTE(review): `wdb` (wandb) is only imported in a commented-out line at
    module top, so a non-None `wandb` raises NameError as the file stands.
    """
    losses = 0.
    dist_acc = 0.
    model.train()
    # Re-enable gradients (the eval_* functions disable them).
    for param in model.parameters():
        param.requires_grad = True
    for data1, data2, data3, target in data_train:
        if torch.cuda.is_available():
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        pred_rt = model.forward(data1, data2, data3)
        # bug fix: Tensor.float() is not in-place; the original discarded
        # the result, leaving `target` in its original dtype.
        target = target.float()
        loss = criterion(pred_rt, target)
        dist = metric(pred_rt, target)
        dist_acc += dist.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

    if wandb is not None:
        wdb.log({"train loss": losses / len(data_train), "train mean metric": dist_acc / len(data_train),
                 'train epoch': epoch})

    print('epoch : ', epoch, 'train losses : ', losses / len(data_train), " mean metric : ",
          dist_acc / len(data_train))
+
+
def eval_int(model, data_val, epoch, criterion, metric, wandb=None):
    """Evaluate the intensity model on `data_val`; returns the mean loss.

    NOTE(review): gradients are disabled by flipping requires_grad on every
    parameter and are only re-enabled by the train_* functions.
    """
    total_loss = 0.
    total_dist = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data1, data2, data3, target in data_val:
        if torch.cuda.is_available():
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        prediction = model.forward(data1, data2, data3)
        total_loss += criterion(prediction, target).item()
        total_dist += metric(prediction, target).item()

    n_batches = len(data_val)
    if wandb is not None:
        wdb.log({"eval loss": total_loss / n_batches, 'eval epoch': epoch, "eval metric": total_dist / n_batches})
    print('epoch : ', epoch, ',eval losses : ', total_loss / n_batches, " ,eval mean metric: :",
          total_dist / n_batches)
    return total_loss / n_batches
+
+
def eval_rt(model, data_val, epoch, criterion, metric, wandb=None):
    """Evaluate the RT model on `data_val`; returns the mean metric value.

    NOTE(review): gradients are disabled by flipping requires_grad on every
    parameter and are only re-enabled by the train_* functions.
    """
    total_loss = 0.
    total_dist = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data, target in data_val:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        prediction = model(data)
        total_loss += criterion(prediction, target).item()
        total_dist += metric(prediction, target).item()

    n_batches = len(data_val)
    if wandb is not None:
        wdb.log({"eval loss": total_loss / n_batches, 'eval epoch': epoch, "eval metric": total_dist / n_batches})
    print('epoch : ', epoch, ',eval losses : ', total_loss / n_batches, " ,eval mean metric: :",
          total_dist / n_batches)

    return total_dist / n_batches
+
+
def eval_pretext(model, data_val, epoch, criterion, metric, wandb=None):
    """Evaluate the pretext-task model on `data_val`; returns the mean metric.

    Only the RT head is scored here — the model's second output (the
    reconstructed sequence) is discarded during evaluation.
    """
    total_loss = 0.
    total_dist = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for data, target in data_val:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        prediction, _ = model(data)
        total_loss += criterion(prediction, target).item()
        total_dist += metric(prediction, target).item()

    n_batches = len(data_val)
    if wandb is not None:
        wdb.log({"eval loss": total_loss / n_batches, 'eval epoch': epoch, "eval metric": total_dist / n_batches})
    print('epoch : ', epoch, ',eval losses : ', total_loss / n_batches, " ,eval mean metric:",
          total_dist / n_batches)

    return total_dist / n_batches
+
+
def save(model, optimizer, epoch, checkpoint_name):
    """Serialise the whole model object under checkpoints/<checkpoint_name>.

    `optimizer` and `epoch` are accepted for interface symmetry with typical
    checkpointing code but are not written to disk.
    """
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    target_path = os.path.join('checkpoints', checkpoint_name)
    torch.save(model, target_path)
+
+
def load(path):
    """Load and return a model previously written by `save` from the
    checkpoints directory."""
    return torch.load(os.path.join('checkpoints', path))
+
+
def run_rt(epochs, eval_inter, save_inter, model, data_train, data_val, optimizer, criterion, metric, wandb=None):
    """Train the RT model for `epochs` epochs, evaluating every `eval_inter`
    epochs and writing a checkpoint every `save_inter` epochs."""
    for epoch in range(1, epochs + 1):
        train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=wandb)
        if epoch % eval_inter == 0:
            eval_rt(model, data_val, epoch, criterion, metric, wandb=wandb)
        if epoch % save_inter == 0:
            save(model, optimizer, epochs, 'model_self_attention_' + str(epoch) + '.pt')
+
+
def run_pretext(epochs, eval_inter, model, data_train, data_val, data_test, optimizer, criterion, task, metric, coef,
                wandb=None):
    """Train with the sequence-reconstruction pretext task, keep the
    checkpoint with the best validation metric, then evaluate it on the
    test set.

    Parameters mirror train_pretext/eval_pretext; `coef` weighs the pretext
    loss and `wandb` (a run name or None) tags the checkpoint file name.
    """
    best_dist = float('inf')
    best_epoch = 0
    if wandb is not None:
        checkpoint_name = 'model_self_attention_pretext_' + wandb + '.pt'
    else:
        checkpoint_name = 'model_self_attention_pretext.pt'
    for e in range(1, epochs + 1):
        train_pretext(model, data_train, e, optimizer, criterion, task, metric, coef, wandb=wandb)
        if e % eval_inter == 0:
            dist = eval_pretext(model, data_val, e, criterion, metric, wandb=wandb)
            if dist < best_dist:
                # bug fix: best_dist was never updated, so every evaluation
                # below the initial sentinel overwrote the "best" checkpoint
                # regardless of whether it actually improved.
                best_dist = dist
                best_epoch = e
                save(model, optimizer, epochs, checkpoint_name)

    model_final = load(checkpoint_name)
    eval_pretext(model_final, data_test, 0, criterion, metric, wandb=wandb)
    print('Best epoch : ' + str(best_epoch))
+
+
def run_int(epochs, eval_inter, save_inter, model, data_train, data_val, optimizer, criterion, metric,
            wandb=None):
    """Train the intensity model for `epochs` epochs, evaluating every
    `eval_inter` epochs and tracking the best validation loss.

    Checkpointing of the best model is currently disabled (it was commented
    out in the original); `save_inter` is accepted but unused for now.
    """
    # bug fix: best-loss tracking was (re)initialised INSIDE the epoch loop,
    # resetting it every epoch; hoisted here so it persists across epochs.
    best_loss = float('inf')
    best_epoch = 0
    for e in range(1, epochs + 1):
        train_int(model, data_train, e, optimizer, criterion, metric, wandb=wandb)
        if e % eval_inter == 0:
            loss = eval_int(model, data_val, e, criterion, metric, wandb=wandb)
            if loss < best_loss:
                best_loss = loss
                best_epoch = e
                # Checkpointing intentionally left disabled, as in the original:
                # save(model, optimizer, epochs,
                #      'model_int' + (wandb if wandb is not None else '') + '.pt')
+
+
def main_rt(args):
    """Entry point for retention-time training: loads data, builds the model
    selected by args.model, and runs the train/eval loop.

    NOTE(review): if args.model matches none of the branches below, `model`
    is never bound and the .cuda()/optimizer lines raise UnboundLocalError.
    """
    if args.wandb is not None:
        # SECURITY(review): hard-coded API key committed to source — rotate
        # it and read it from the environment instead.
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        # NOTE(review): `wdb` (wandb) is only imported in a commented-out
        # line at module top; this branch raises NameError as the file stands.
        wdb.init(project="RT prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())
    if args.dataset_train == args.dataset_test:
        data_train, data_val, data_test = load_data(batch_size=args.batch_size, n_train=args.n_train, n_test=args.n_test,
                                                    data_sources=[args.dataset_train, args.dataset_train, args.dataset_train])
    else:
        data_train, data_val, data_test = load_data(batch_size=args.batch_size, n_train=args.n_train, n_test=args.n_test,
                                                    data_sources=[args.dataset_train, args.dataset_train, args.dataset_test])
    print('\nData loaded')
    if args.model == 'RT_multi':
        # bug fix: `args.layers_size[2]` -> `args.layers_sizes[2]` (attribute
        # typo raised AttributeError whenever this branch was taken).
        model = RT_pred_model_self_attention_multi(
            recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1], args.layers_sizes[2]),
            regressor_layer_size=args.layers_sizes[3])
    if args.model == 'RT_self_att' or args.model == 'RT_multi_sum':
        model = RT_pred_model_self_attention_multi_sum(
            n_head=args.n_head,
            recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1]),
            regressor_layer_size=args.layers_sizes[2])
    if args.model == 'RT_transformer':
        model = RT_pred_model_transformer(regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')
    run_rt(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_val, optimizer=optimizer,
           criterion=torch.nn.MSELoss(), metric=distance, wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()
+
+
def main_pretext(args):
    """Entry point for RT training with the sequence-reconstruction pretext
    task: loads data, builds the pretext model, and runs run_pretext."""
    if args.wandb is not None:
        # SECURITY(review): hard-coded API key committed to source — rotate
        # it and read it from the environment instead.
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        # NOTE(review): `wdb` (wandb) is only imported in a commented-out
        # line at module top; this branch raises NameError as the file stands.
        wdb.init(project="RT prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())
    # NOTE(review): load_data is called here with positional args and a
    # `data_source=` keyword, while main_rt uses `data_sources=[...]` —
    # one of the two call shapes is presumably wrong; check dataloader.load_data.
    if args.dataset_train == args.dataset_test:
        data_train, data_val, data_test = load_data(args.batch_size, args.n_train, args.n_test,
                                                    data_source=args.dataset_train)
    else:
        # Different train/test corpora: take the train split from one source
        # and the val/test splits from the other.
        data_train, _, _ = load_data(args.batch_size, args.n_train, args.n_test,
                                     data_source=args.dataset_train)
        _, data_val, data_test = load_data(args.batch_size, args.n_train, args.n_test,
                                           data_source=args.dataset_test)
    print('\nData loaded')
    model = RT_pred_model_self_attention_pretext(recurrent_layers_sizes=(args.layers_sizes[0],args.layers_sizes[1]), regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')
    run_pretext(args.epochs, args.eval_inter, model, data_train, data_val, data_test, optimizer=optimizer,
                criterion=torch.nn.MSELoss(), task=torch.nn.CrossEntropyLoss(), metric=distance, coef=args.coef_pretext,
                wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()
+
+
def main_int(args):
    """Entry point for fragment-intensity training: loads the .npy arrays,
    builds the multi-head intensity model, and runs the train/eval loop."""
    if args.wandb is not None:
        # SECURITY(review): hard-coded API key committed to source — rotate
        # it and read it from the environment instead.
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'

        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        # NOTE(review): `wdb` (wandb) is only imported in a commented-out
        # line at module top; this branch raises NameError as the file stands.
        wdb.init(project="Intensity prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print(torch.cuda.is_available())

    # (sequence, intensity, collision energy, precursor charge) file paths.
    sources_train = ('data/intensity/sequence_train.npy',
                     'data/intensity/intensity_train.npy',
                     'data/intensity/collision_energy_train.npy',
                     'data/intensity/precursor_charge_train.npy')

    sources_test = ('data/intensity/sequence_test.npy',
                    'data/intensity/intensity_test.npy',
                    'data/intensity/collision_energy_test.npy',
                    'data/intensity/precursor_charge_test.npy')

    data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3],
                                           args.batch_size)
    data_val = load_intensity_from_files(sources_test[0], sources_test[1], sources_test[2], sources_test[3],
                                         args.batch_size)

    print('\nData loaded')
    model = Intensity_pred_model_multi_head(recurrent_layers_sizes=(args.layers_sizes[0], args.layers_sizes[1]),
                                            regressor_layer_size=args.layers_sizes[2])
    if torch.cuda.is_available():
        model = model.cuda()
    # bug fix: honour args.lr instead of a hard-coded 0.001 (main_rt and
    # main_pretext already use args.lr — this branch silently ignored it).
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')
    run_int(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_val, optimizer=optimizer,
            criterion=masked_cos_sim, metric=masked_spectral_angle, wandb=args.wandb)

    if args.wandb is not None:
        wdb.finish()
+
+
if __name__ == "__main__":
    args = load_args()

    # Dispatch on the requested model family.
    if args.model in ('RT_self_att', 'RT_multi', 'RT_multi_sum', 'RT_transformer'):
        main_rt(args)
    elif args.model == 'Intensity_multi_head':
        main_int(args)
    elif args.model == 'RT_pretext':
        main_pretext(args)
diff --git a/main_custom.py b/main_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..94bf357516ba9761d9aed6123c0def129c588cea
--- /dev/null
+++ b/main_custom.py
@@ -0,0 +1,253 @@
+import os
+import torch
+import torch.optim as optim
+import wandb as wdb
+
+import common_dataset
+import dataloader
+from config_common import load_args
+from common_dataset import load_data
+from dataloader import load_data
+from loss import masked_cos_sim, distance, masked_spectral_angle
+from model_custom import Model_Common_Transformer, Model_Common_Transformer_TAPE
+from model import RT_pred_model_self_attention_multi
+
+
def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
          wandb=None):
    """Run one training epoch over ``data_train``.

    Args:
        model: network exposing ``forward(seq, charge)``, ``forward_rt(seq)``
            and ``forward_int(seq, charge)``.
        data_train: iterable of batches; tuple layout depends on ``forward``.
        epoch: current epoch number (used for logging only).
        optimizer: torch optimizer stepping ``model``'s parameters.
        criterion_rt: training loss for the retention-time task.
        criterion_intensity: training loss for the intensity task.
        metric_rt: monitoring metric for RT (not optimized).
        metric_intensity: monitoring metric for intensity (not optimized).
        forward: 'both', 'rt' or 'int' — selects the task(s) trained on.
        wandb: run name; when not None, epoch means are logged via wandb.
    """
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.train()
    # Re-enable gradients; eval() sets requires_grad=False between epochs.
    for param in model.parameters():
        param.requires_grad = True
    if forward == 'both':
        for seq, charge, rt, intensity in data_train:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            # Both tasks are optimized with equal weight.
            loss = loss_rt + loss_int
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()
            losses_rt += loss_rt.item()
            # BUG FIX: the logged intensity loss was scaled by 5 while the
            # optimized loss was not, so reported numbers were inconsistent
            # with the objective actually minimized.
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train), "train int loss": losses_int / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), 'train int loss',
              losses_int / len(data_train), "train rt mean metric : ", dist_rt_acc / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))

    if forward == 'rt':
        for seq, rt in data_train:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            loss = loss_rt
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()
            losses_rt += loss_rt.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), "train rt mean metric : ",
              dist_rt_acc / len(data_train))

    if forward == 'int':
        for seq, charge, intensity in data_train:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_int
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train int loss": losses_int / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train int loss',
              losses_int / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))
+
+
def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
    """Run one validation epoch over ``data_val`` and log epoch-mean losses/metrics.

    Mirrors :func:`train`'s three ``forward`` modes ('both', 'rt', 'int');
    the 'rt' and 'int' modes unpack batches in the prosit dataset format.
    No optimizer step is taken.
    """
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    # Freeze parameters so no gradients are built during validation;
    # train() re-enables them at the start of the next epoch.
    for param in model.parameters():
        param.requires_grad = False
    # BUG FIX: switch off train-time behaviour (dropout, batch-norm updates).
    # Previously the model stayed in training mode during validation, making
    # validation losses noisy and biased.
    model.eval()
    if forward == 'both':
        for seq, charge, rt, intensity in data_val:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_rt += loss_rt.item()
            losses_int += loss_int.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()

        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val), "val int loss": losses_int / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val), 'val int loss', losses_int / len(data_val),
              "val rt mean metric : ",
              dist_rt_acc / len(data_val), "val int mean metric", dist_int_acc / len(data_val))

    if forward == 'rt':  # adapted to prosit dataset format
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            losses_rt += loss_rt.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()

        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val),
              "val rt mean metric : ",
              dist_rt_acc / len(data_val))

    if forward == 'int':  # adapted to prosit dataset format (rt column ignored)
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_int += loss_int.item()
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
        if wandb is not None:
            wdb.log({"val int loss": losses_int / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val int loss', losses_int / len(data_val), "val int mean metric",
              dist_int_acc / len(data_val))
+
+
def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
    """Full training loop: train every epoch, evaluate and checkpoint periodically.

    ``data_test`` is accepted for interface compatibility but is not used here.
    """
    for epoch in range(1, epochs + 1):
        train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity,
              metric_rt, metric_intensity, forward, wandb=wandb)
        if epoch % eval_inter == 0:
            eval(model, data_val, epoch, criterion_rt, criterion_intensity,
                 metric_rt, metric_intensity, forward, wandb=wandb)
        if epoch % save_inter == 0:
            save(model, 'model_common_' + str(epoch) + '.pt')
+
+
def main(args):
    """CLI entry point: configure wandb, load data, build and train the model.

    Args:
        args: parsed CLI namespace (dataset paths, model hyper-parameters,
            ``forward`` mode, wandb run name, ...).

    Raises:
        ValueError: if ``args.forward`` is not a mode this entry point supports.
    """
    if args.wandb is not None:
        # SECURITY NOTE(review): API key hard-coded in source — move it to an
        # environment variable or secret store and rotate the key. Kept here
        # so existing offline runs keep working.
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())

    if args.forward == 'both':
        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                   path_val=args.dataset_val,
                                                                   path_test=args.dataset_test,
                                                                   batch_size=args.batch_size, length=25, pad=False,
                                                                   convert=True, vocab='iapuc')
    elif args.forward == 'rt':
        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
                                                               batch_size=args.batch_size, length=25)
    else:
        # BUG FIX: any other mode previously fell through and crashed later
        # with a NameError on data_train; fail fast with a clear message.
        raise ValueError("Unsupported forward mode for this entry point: {!r}".format(args.forward))

    print('\nData loaded')

    model = Model_Common_Transformer_TAPE(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                                          decoder_int_ff=args.decoder_int_ff,
                                          n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                          decoder_int_num_layer=args.decoder_int_num_layer,
                                          decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                                          embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
        data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
        criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
        wandb=args.wandb, forward=args.forward)

    if args.wandb is not None:
        wdb.finish()
+
+
def save(model, checkpoint_name):
    """Serialize the full model object to ./checkpoints/<checkpoint_name>."""
    print('\nModel Saving...')
    target_dir = 'checkpoints'
    os.makedirs(target_dir, exist_ok=True)
    torch.save(model, os.path.join(target_dir, checkpoint_name))
+
+
def load(path):
    """Deserialize and return a model previously written by :func:`save`."""
    return torch.load(os.path.join('checkpoints', path))
+
+
def get_n_params(model):
    """Print each named parameter with its element count and return the total.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        int: total number of scalar parameters in ``model``.
    """
    total = 0
    for name, param in model.named_parameters():
        # numel() replaces the manual product over param.size(); the
        # redundant list() wrappers around the iterables are dropped too.
        count = param.numel()
        print(name, count)
        total += count
    return total
+
+
if __name__ == "__main__":
    # Script entry point: parse the CLI configuration and launch training.
    args = load_args()
    main(args)
+
diff --git a/main_mz_image.py b/main_mz_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..dacb688963ad99821110e55582b58a7609d1ba3e
--- /dev/null
+++ b/main_mz_image.py
@@ -0,0 +1,20 @@
+import numpy as np
+import pyopenms as oms
+from mzml_exploration import build_image_generic
+import os
+import glob
+
+
# Glob pattern for the input spectra and image-binning parameters.
DIR_NAME = 'data/mzml/*.mzml'
BIN_SIZE = 2
NUM_RT = 300

if __name__ == '__main__':
    # BUG FIX: was `'__name__' == '__main__'` — a comparison of two string
    # literals that is always False, so the script silently did nothing.
    mzml_files = glob.glob(DIR_NAME, recursive=False, include_hidden=False)
    for f_name in mzml_files:
        experiment = oms.MSExperiment()
        oms.MzMLFile().load(f_name, experiment)
        image = build_image_generic(experiment, BIN_SIZE, NUM_RT)
        # Log-compress intensities; +1 keeps zero bins finite.
        compressed = np.maximum(0, np.log(image + 1))
        # BUG FIX: was os.path.splitext('f_name') (a string literal), which
        # wrote every image to 'f_name.npy' instead of next to its input file.
        np.save(os.path.splitext(f_name)[0] + '.npy', compressed)
diff --git a/main_ray_tune.py b/main_ray_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..50036dc1a665f38b9ae3041365521bff4498ceaa
--- /dev/null
+++ b/main_ray_tune.py
@@ -0,0 +1,341 @@
+import os
+import tempfile
+
+import torch
+import torch.optim as optim
+from ray.air import RunConfig, CheckpointConfig
+from ray.tune.search.ax import AxSearch
+from ray.tune.search.bayesopt import BayesOptSearch
+from ray.tune.search.bohb import TuneBOHB
+from ray.tune.search.optuna import OptunaSearch
+from ray.util.client import ray
+
+import common_dataset
+import dataloader
+from config_common import load_args
+from loss import masked_cos_sim
+from model_custom import Model_Common_Transformer
+from ray import train, tune
+from ray.train import Checkpoint
+from ray.tune.schedulers import HyperBandForBOHB, ASHAScheduler
+
+
def train_model(config, args):
    """Ray Tune trainable: build the model, optionally restore it, then train.

    Args:
        config: one hyper-parameter sample drawn by the search algorithm.
        args: parsed CLI arguments (dataset paths, ``forward`` mode, ...).

    Reports the mean validation loss and a checkpoint to Ray after every epoch.
    """
    net = Model_Common_Transformer(encoder_ff=int(config["encoder_ff"]),
                                   decoder_rt_ff=int(config["decoder_rt_ff"]),
                                   decoder_int_ff=int(config["decoder_int_ff"]),
                                   n_head=int(config["n_head"]),
                                   encoder_num_layer=int(config["encoder_num_layer"]),
                                   decoder_int_num_layer=int(config["decoder_int_num_layer"]),
                                   decoder_rt_num_layer=int(config["decoder_rt_num_layer"]),
                                   drop_rate=float(config["drop_rate"]),
                                   embedding_dim=int(config["embedding_dim"]),
                                   acti=config["activation"],
                                   norm=config["norm_first"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            print(type(net))
            net = torch.nn.DataParallel(net)
            print(type(net))
    net.to(device)

    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    # Load existing checkpoint through the `get_checkpoint()` API so a
    # preempted/paused trial can resume where it left off.
    if train.get_checkpoint():
        loaded_checkpoint = train.get_checkpoint()
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            )
            net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    if args.forward == 'both':
        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                   path_val=args.dataset_val,
                                                                   path_test=args.dataset_test,
                                                                   batch_size=int(config["batch_size"]), length=25)
    else:
        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
                                                               batch_size=int(config["batch_size"]), length=25)

    for epoch in range(100):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(data_train):

            if args.forward == 'rt':
                seq, rt = data
                rt = rt.float()
                if torch.cuda.is_available():
                    seq, rt = seq.cuda(), rt.cuda()

                # DataParallel wraps the model, so task-specific entry points
                # must be reached through .module.
                if torch.cuda.device_count() > 1:
                    pred_rt = net.module.forward_rt(seq)
                else:
                    pred_rt = net.forward_rt(seq)

                loss = criterion_rt(rt, pred_rt)

            elif args.forward == 'int':
                seq, charge, intensity = data
                intensity = intensity.float()
                if torch.cuda.is_available():
                    seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()

                if torch.cuda.device_count() > 1:
                    pred_int = net.module.forward_int(seq, charge)
                else:
                    pred_int = net.forward_int(seq, charge)

                loss = criterion_intensity(intensity, pred_int)

            else:
                seq, charge, rt, intensity = data
                rt, intensity = rt.float(), intensity.float()
                if torch.cuda.is_available():
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pred_rt, pred_int = net(seq, charge)
                loss_rt = criterion_rt(rt, pred_rt)
                loss_int = criterion_intensity(intensity, pred_int)
                loss = loss_rt + loss_int

            running_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # print statistics
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        for i, data in enumerate(data_val, 0):
            with torch.no_grad():
                if args.forward == 'rt':
                    seq, rt = data
                    rt = rt.float()
                    if torch.cuda.is_available():
                        seq, rt = seq.cuda(), rt.cuda()

                    if torch.cuda.device_count() > 1:
                        pred_rt = net.module.forward_rt(seq)
                    else:
                        pred_rt = net.forward_rt(seq)

                    loss = criterion_rt(rt, pred_rt)

                elif args.forward == 'int':
                    seq, charge, intensity = data
                    intensity = intensity.float()
                    if torch.cuda.is_available():
                        seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()

                    if torch.cuda.device_count() > 1:
                        pred_int = net.module.forward_int(seq, charge)
                    else:
                        pred_int = net.forward_int(seq, charge)

                    loss = criterion_intensity(intensity, pred_int)

                else:
                    seq, charge, rt, intensity = data
                    rt, intensity = rt.float(), intensity.float()
                    if torch.cuda.is_available():
                        seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                    pred_rt, pred_int = net(seq, charge)
                    loss_rt = criterion_rt(rt, pred_rt)
                    loss_int = criterion_intensity(intensity, pred_int)
                    loss = loss_rt + loss_int
                # BUG FIX: was `loss.item().numpy()` — .item() returns a plain
                # Python float, which has no .numpy(), so the first validation
                # batch raised AttributeError and killed every trial.
                val_loss += loss.item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be accessed through in ``get_checkpoint()``
        # in future iterations.
        # Note to save a file like checkpoint, you still need to put it under a directory
        # to construct a checkpoint.
        with tempfile.TemporaryDirectory(
                dir='/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/checkpoints') as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")

            torch.save(
                (net.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report(
                {"loss": (val_loss / val_steps)},
                checkpoint=checkpoint,
            )
    print("Finished Training")
+
+
def test_best_model(best_result, args):
    """Rebuild the best trial's model from its checkpoint and score the test set.

    Args:
        best_result: the Ray Tune result object of the best trial.
        args: parsed CLI arguments (dataset paths, ``forward`` mode, ...).

    Raises:
        ValueError: if ``args.forward`` is not 'rt', 'int' or 'both'.
    """
    # BUG FIX: encoder_num_layer was wired to config["batch_size"], so the
    # reconstructed model never matched the trained architecture and loading
    # the checkpoint state dict would fail (or silently build a wrong model).
    best_trained_model = Model_Common_Transformer(encoder_ff=best_result.config["encoder_ff"],
                                                  decoder_rt_ff=best_result.config["decoder_rt_ff"],
                                                  decoder_int_ff=best_result.config["decoder_int_ff"],
                                                  n_head=best_result.config["n_head"],
                                                  encoder_num_layer=best_result.config["encoder_num_layer"],
                                                  decoder_int_num_layer=best_result.config["decoder_int_num_layer"],
                                                  decoder_rt_num_layer=best_result.config["decoder_rt_num_layer"],
                                                  drop_rate=best_result.config["drop_rate"],
                                                  embedding_dim=best_result.config["embedding_dim"],
                                                  acti=best_result.config["activation"],
                                                  norm=best_result.config["norm_first"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            best_trained_model = torch.nn.DataParallel(best_trained_model)

    best_trained_model.to(device)
    criterion_rt = torch.nn.MSELoss()
    criterion_intensity = masked_cos_sim
    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    model_state, optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    if args.forward == 'both':
        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                   path_val=args.dataset_val,
                                                                   path_test=args.dataset_test,
                                                                   batch_size=best_result.config["batch_size"],
                                                                   length=25)
    else:
        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
                                                               batch_size=best_result.config["batch_size"], length=25)
    val_loss = 0
    val_steps = 0
    with torch.no_grad():
        for data in data_test:
            if args.forward == 'rt':
                seq, rt = data
                rt = rt.float()
                if torch.cuda.is_available():
                    seq, rt = seq.cuda(), rt.cuda()

                if torch.cuda.device_count() > 1:
                    pred_rt = best_trained_model.module.forward_rt(seq)
                else:
                    pred_rt = best_trained_model.forward_rt(seq)

                loss = criterion_rt(rt, pred_rt)

            elif args.forward == 'int':
                seq, charge, intensity = data
                intensity = intensity.float()
                if torch.cuda.is_available():
                    seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()

                if torch.cuda.device_count() > 1:
                    pred_int = best_trained_model.module.forward_int(seq, charge)
                else:
                    pred_int = best_trained_model.forward_int(seq, charge)

                loss = criterion_intensity(intensity, pred_int)

            elif args.forward == 'both':
                seq, charge, rt, intensity = data
                rt, intensity = rt.float(), intensity.float()
                if torch.cuda.is_available():
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pred_rt, pred_int = best_trained_model(seq, charge)
                loss_rt = criterion_rt(rt, pred_rt)
                loss_int = criterion_intensity(intensity, pred_int)
                loss = loss_rt + loss_int
            else:
                # BUG FIX: previously an unknown mode left `loss` undefined
                # and crashed with a NameError on the next line.
                raise ValueError("Unsupported forward mode: {!r}".format(args.forward))
            # BUG FIX: was `loss.item().numpy()` — floats have no .numpy().
            val_loss += loss.item()
            val_steps += 1
    # BUG FIX: message previously had a scheduler class name pasted into it
    # ("AsyncHyperBandSchedulerloss").
    print("Best trial test set loss: {}".format(val_loss))
+
+
def main(args, gpus_per_trial=1):
    """Configure and launch the Ray Tune hyper-parameter search.

    Args:
        args: parsed CLI arguments forwarded to every trial.
        gpus_per_trial: number of GPUs requested per Tune trial.

    Prints the best trial's config and validation loss, then evaluates
    the best checkpoint on the test set.
    """
    # Search space sampled by Optuna; ASHA prunes under-performing trials.
    config = {
        "encoder_num_layer": tune.choice([2, 4, 8]),
        "decoder_rt_num_layer": tune.choice([2, 4, 8]),
        "decoder_int_num_layer": tune.choice([1]),
        "embedding_dim": tune.choice([16, 64]),
        "encoder_ff": tune.choice([512, 1024, 2048]),
        "decoder_rt_ff": tune.choice([512, 1024, 2048]),
        "decoder_int_ff": tune.choice([512]),
        "n_head": tune.choice([1, 2, 4, 8, 16]),
        "drop_rate": tune.choice([0.25]),
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([4096]),
        "activation": tune.choice(['relu', 'gelu']),
        "norm_first": tune.choice([True, False]),
    }
    scheduler = ASHAScheduler(
        max_t=100,
        grace_period=30,
        reduction_factor=3,
        brackets=1,
    )
    algo = OptunaSearch()

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": 80, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=20,
            metric='loss',
            mode='min',
        ),
        run_config=RunConfig(storage_path="/gpfswork/rech/ute/ucg81ws/these/LC-MS-RT-prediction/ray_results_test",
                             name="test_experiment_no_scheduler"
                             ),
        param_space=config
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    # BUG FIX: trials only report "loss", so metrics["accuracy"] always raised
    # KeyError here — after the whole (23h budget) search had completed.
    accuracy = best_result.metrics.get("accuracy")
    if accuracy is not None:
        print("Best trial final validation accuracy: {}".format(accuracy))

    test_best_model(best_result, args)
+
+
if __name__ == "__main__":
    # Log the visible GPU models, fix the RNG seed so trials are reproducible,
    # then start the Ray Tune hyper-parameter search (4 GPUs per trial).
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_properties(i).name)
    torch.manual_seed(2809)
    arg = load_args()
    main(arg, gpus_per_trial=4)
diff --git a/mass_prediction.py b/mass_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..603e1fd1f259b3912f442edb68273b839ac07769
--- /dev/null
+++ b/mass_prediction.py
@@ -0,0 +1,122 @@
+# MASS CST DICT
+import numpy as np
+
# Monoisotopic residue masses (Da) of the 20 standard amino acids, i.e. the
# mass each residue contributes inside a peptide chain (amino acid minus H2O).
MASSES_MONO = {
    "A": 71.03711,
    "C": 103.00919,
    "D": 115.02694,
    "E": 129.04259,
    "F": 147.06841,
    "G": 57.02146,
    "H": 137.05891,
    "I": 113.08406,
    "K": 128.09496,
    "L": 113.08406,
    "M": 131.04049,
    "N": 114.04293,
    "P": 97.05276,
    "Q": 128.05858,
    "R": 156.10111,  # BUG FIX: was 156.1875 — arginine's *average* mass
    "S": 87.03203,
    "T": 101.04768,
    "V": 99.06841,
    "W": 186.07931,
    "Y": 163.06333,
}
+
# Average (isotope-abundance-weighted) residue masses (Da) of the 20 standard
# amino acids — the mass each residue contributes inside a peptide chain.
MASSES_AVG = {
    "A": 71.0788,
    "C": 103.1388,
    "D": 115.0886,
    "E": 129.1155,
    "F": 147.1766,
    "G": 57.0519,
    "H": 137.1411,
    "I": 113.1594,
    "K": 128.1741,
    "L": 113.1594,
    "M": 131.1926,
    "N": 114.1038,
    "P": 97.1167,
    "Q": 128.1307,
    "R": 156.1875,
    "S": 87.0782,
    "T": 101.1051,
    "V": 99.1326,
    "W": 186.2132,
    "Y": 163.1760,
}
+
# Monoisotopic masses (Da) of common post-translational modifications plus a
# few elemental/molecular constants (H, H+, O, H2O).
# NOTE(review): the name ends in a zero ('MON0', not 'MONO') — kept as-is for
# compatibility with any existing importers; consider renaming with an alias.
# NOTE(review): the PTM entries appear to be full modified-residue masses, not
# mass deltas — confirm before using them additively with MASSES_MONO.
PTMs_MON0 = {
    'Alkylation': 14.01564,
    'Carbamylation': 43.00581,
    'Carboxymethyl cysteine (Cys_CM)': 161.01466,
    'Carboxyamidomethyl cysteine (Cys_CAM)': 160.03065,
    'Pyridyl-ethyl cysteine (Cys_PE)': 208.067039,
    'Propionamide cysteine (Cys_PAM)': 174.04631,
    'Methionine sulfoxide (MSO)': 147.0354,
    'Oxydized tryptophan (TPO)': 202.0742,
    'Homoserine Lactone (HSL)': 100.03985,
    'H': 1.00783,
    'H+': 1.00728,
    'O': 15.9949146,
    'H2O': 18.01056,
}
+
# Average masses (Da) of the same modifications/constants as PTMs_MON0.
# NOTE(review): as with PTMs_MON0, these look like full modified-residue
# masses rather than deltas — confirm against the source table.
PTMs_AVG = {
    'Alkylation': 14.02688,
    'Carbamylation': 43.02502,
    'Carboxymethyl cysteine (Cys_CM)': 161.1755,
    'Carboxyamidomethyl cysteine (Cys_CAM)': 160.1908,
    'Pyridyl-ethyl cysteine (Cys_PE)': 208.284,
    'Propionamide cysteine (Cys_PAM)': 174.2176,
    'Methionine sulfoxide (MSO)': 147.1920,
    'Oxydized tryptophan (TPO)': 202.2126,
    'Homoserine Lactone (HSL)': 100.09714,
    'H': 1.00794,
    'H+': 1.00739,
    'O': 15.9994,
    'H2O': 18.01524,
}
+
+
def compute_mass(seq, isotop, mod=False):
    """Return the summed residue mass (Da) of peptide ``seq``.

    Args:
        seq: peptide string of one-letter amino-acid codes.
        isotop: 'mono' for monoisotopic masses, 'avg' for average masses;
            any other value yields 0 (matching the original behaviour).
        mod: reserved for PTM support; currently has no effect — the original
            implementation had byte-identical branches for both values.

    Note: termini are not accounted for — the water (H2O) of an intact
    peptide is not added here.
    """
    # TODO: implement PTM-aware masses when mod=True.
    if isotop == 'mono':
        table = MASSES_MONO
    elif isotop == 'avg':
        table = MASSES_AVG
    else:
        return 0
    return sum(mass * seq.count(aa) for aa, mass in table.items())
+
+
def compute_frag_mz_ration(seq, isotop, mod=False):
    """Build the 174-long Prosit-style fragment mass vector for ``seq``.

    For each cleavage position i (0 <= i < len(seq) - 1) six slots are
    filled: [y, y/2, y/3, b, b/2, b/3]; unused slots keep the padding
    value -1.

    Args:
        seq: peptide string of one-letter amino-acid codes.
        isotop: 'mono' or 'avg' mass table selection.
        mod: reserved for PTM support (TODO — currently ignored).

    Returns:
        numpy.ndarray of shape (174,), dtype float.

    NOTE(review): charge states are approximated as mass/z; proton and
    terminal (H2O) masses are not added — confirm this matches the
    consumer's expected m/z convention.
    """
    # BUG FIX: use a float array — the original `np.array([-1] * 174)` had an
    # integer dtype, so every fragment mass was silently truncated to an int
    # on assignment.
    masses = np.full(174, -1.0)
    acc_b = 0
    acc_y = 0
    n = len(seq)

    # TODO: PTMs (mod=True) are not handled yet.
    for i in range(n - 1):
        if isotop == 'avg':
            # BUG FIX: b ions accumulate from the FIRST residue. The original
            # indexed seq[i - 1], which starts at seq[-1] (the last residue)
            # and shifts the whole b-ion series by one.
            acc_b += MASSES_AVG[seq[i]]
            acc_y += MASSES_AVG[seq[n - 1 - i]]
        if isotop == 'mono':
            acc_b += MASSES_MONO[seq[i]]
            acc_y += MASSES_MONO[seq[n - 1 - i]]
        masses[6 * i] = acc_y
        masses[6 * i + 1] = acc_y / 2
        masses[6 * i + 2] = acc_y / 3
        masses[6 * i + 3] = acc_b
        masses[6 * i + 4] = acc_b / 2
        masses[6 * i + 5] = acc_b / 3
    return masses
diff --git a/metrics_lr/events.out.tfevents.1711637459.r3i4n0.2741367.0.v2 b/metrics_lr/events.out.tfevents.1711637459.r3i4n0.2741367.0.v2
new file mode 100644
index 0000000000000000000000000000000000000000..939d57cce34402e2dd183c9875880f78f1e316ff
Binary files /dev/null and b/metrics_lr/events.out.tfevents.1711637459.r3i4n0.2741367.0.v2 differ
diff --git a/metrics_lr/events.out.tfevents.1711640539.r3i5n0.851015.0.v2 b/metrics_lr/events.out.tfevents.1711640539.r3i5n0.851015.0.v2
new file mode 100644
index 0000000000000000000000000000000000000000..2c511e302707d4af159c6b4fdb9e8b8a912090ad
Binary files /dev/null and b/metrics_lr/events.out.tfevents.1711640539.r3i5n0.851015.0.v2 differ
diff --git a/metrics_lr/events.out.tfevents.1711642953.r10i3n0.4169735.0.v2 b/metrics_lr/events.out.tfevents.1711642953.r10i3n0.4169735.0.v2
new file mode 100644
index 0000000000000000000000000000000000000000..ebc26e4e43bf5a801ad381af425317fb26187a0b
Binary files /dev/null and b/metrics_lr/events.out.tfevents.1711642953.r10i3n0.4169735.0.v2 differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b9b12e91abf63e65dc69d95d5c0b0354a070b13
--- /dev/null
+++ b/model.py
@@ -0,0 +1,385 @@
+import numpy as np
+import torch.nn as nn
+import torch
+
+from layers import SelectItem, SelfAttention_multi, SelfAttention, TransformerEncoder
+
+
class RT_pred_model(nn.Module):
    """Baseline GRU retention-time regressor.

    Pipeline: one-hot residues (24 classes) -> linear embedding (8) ->
    2-layer bidirectional GRU -> MLP over the flattened final hidden
    states -> one RT scalar per sequence.
    """

    def __init__(self, drop_rate):
        super(RT_pred_model, self).__init__()
        # SelectItem(1) presumably keeps element 1 of the GRU's
        # (output, h_n) tuple, i.e. the final hidden state of shape
        # (num_layers * 2 directions, batch, 16) — confirm in layers.py.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=8, hidden_size=16, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True),
            SelectItem(1),
            nn.Dropout(p=drop_rate)
        )

        # 64 = 2 layers * 2 directions * hidden_size 16, flattened in forward().
        self.decoder = nn.Sequential(
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.Linear(8, 1)
        )

        # Dense embedding of the 24-class one-hot residue vector.
        self.emb = nn.Linear(24, 8)

        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # seq: LongTensor of residue indices — assumed (batch, seq_len); TODO confirm.
        x = torch.nn.functional.one_hot(seq, 24)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        # h_n is (layers*directions, batch, hidden): put batch first, then
        # flatten all hidden states into one feature vector per sequence.
        x_enc = x_enc.swapaxes(0, 1)
        x_enc = torch.flatten(x_enc, start_dim=1)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt
+
+
+# To remove if multi_sum works
class RT_pred_model_self_attention(nn.Module):
    """Retention-time regressor: two stacked bi-directional GRUs followed
    by a self-attention pooling head (SelfAttention_multi) and an MLP.

    NOTE(review): hyper-parameters are assigned to ``self`` before
    ``super().__init__()`` runs; this only works because they are plain
    Python values (no Parameters/Modules yet).
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(RT_pred_model_self_attention, self).__init__()
        # SelectItem(0) presumably keeps the per-position output sequence
        # (element 0 of the GRU's (output, h_n) tuple) — confirm in layers.py.
        # NOTE(review): dropout= on a 1-layer GRU has no effect (PyTorch
        # applies it only between stacked layers, and warns).
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   dropout=self.drop_rate,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, dropout=self.drop_rate, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # SelfAttention_multi with 1 head pools the sequence dimension
        # (presumably) before the regression MLP — confirm in layers.py.
        self.decoder = nn.Sequential(
            SelfAttention_multi(self.recurrent_layers_sizes[1] * 2, 1),
            nn.Linear(self.recurrent_layers_sizes[1] * 2, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # seq: LongTensor of residue indices in [0, nb_aa).
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt
+
+
class RT_pred_model_self_attention_multi(nn.Module):
    """RT regressor: bi-GRU encoder -> multi-head self-attention ->
    GRU decoder whose final hidden state feeds an MLP regressor.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512, 512), n_head=8):
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(RT_pred_model_self_attention_multi, self).__init__()
        # SelectItem(0): keep the GRU output sequence, drop h_n.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # NOTE(review): the encoder GRUs are batch_first, but
        # nn.MultiheadAttention defaults to seq-first input — as written,
        # attention mixes over the batch dimension; confirm intended.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, self.n_head)

        # SelectItem(1): keep the decoder GRU's final hidden state h_n.
        self.decoder = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.recurrent_layers_sizes[2],
                   num_layers=1,
                   bidirectional=False,
                   batch_first=True),
            SelectItem(1),
            nn.Linear(self.recurrent_layers_sizes[2], regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        # NOTE(review): self.regressor and self.select are never used in
        # forward(); they still add parameters/entries to the state dict.
        self.regressor = nn.Linear(self.regressor_layer_size, self.nb_aa)

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.select = SelectItem(0)
        self.regressor.float()
        self.attention.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # seq: LongTensor of residue indices in [0, nb_aa).
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        # Self-attention over the encoder features (weights discarded).
        x_att, _ = self.attention(x_enc, x_enc, x_enc)
        x_rt = self.decoder(x_att)
        x_rt_flat = torch.flatten(x_rt)
        return x_rt_flat
+
+
class RT_pred_model_self_attention_multi_sum(nn.Module):
    """RT regressor: two stacked bi-GRUs, then a SelfAttention_multi
    pooling head (n_head heads) feeding the regression MLP.

    Same architecture as RT_pred_model_self_attention but with a
    configurable number of attention heads and no GRU-internal dropout.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), n_head=1):
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(RT_pred_model_self_attention_multi_sum, self).__init__()
        # SelectItem(0): keep the per-position GRU output sequence.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # SelfAttention_multi presumably pools the sequence dimension into
        # one vector per sequence — confirm in layers.py.
        self.decoder = nn.Sequential(
            SelfAttention_multi(self.recurrent_layers_sizes[1] * 2, self.n_head),
            nn.Linear(self.recurrent_layers_sizes[1] * 2, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # seq: LongTensor of residue indices in [0, nb_aa).
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt
+
+
class RT_pred_model_transformer(nn.Module):
    """Transformer retention-time regressor.

    One-hot residues -> linear embedding, summed with a learned positional
    embedding over a fixed length of 30, then transformer encoder/decoder
    stacks (project-local TransformerEncoder) and an MLP head producing
    one RT scalar per sequence.
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=128, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512, n_head=1):
        self.drop_rate = drop_rate
        self.n_head = n_head
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(RT_pred_model_transformer, self).__init__()
        self.encoder = nn.Sequential(
            TransformerEncoder(1, input_dim=embedding_output_dim, num_heads=self.n_head, dim_feedforward=512,
                               dropout=self.drop_rate)
        )

        self.decoder = nn.Sequential(
            TransformerEncoder(1, input_dim=embedding_output_dim, num_heads=self.n_head, dim_feedforward=512,
                               dropout=self.drop_rate),
            nn.Flatten(),
            # 30 = fixed (padded) sequence length this model assumes.
            nn.Linear(embedding_output_dim * 30, self.regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(self.regressor_layer_size, 1)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)
        # Learned positional embedding: one-hot position -> embedding_dim.
        self.pos_embedding = nn.Linear(30, self.embedding_output_dim)

        self.pos_embedding.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # Build position indices on the input's device: the previous
        # torch.tensor([i for i in range(30)]) literal always lived on the
        # CPU, causing a device mismatch as soon as the model ran on GPU.
        indices = torch.arange(30, device=seq.device)
        indice_ohe = torch.nn.functional.one_hot(indices, 30)
        x_ind = self.pos_embedding(indice_ohe.float())
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        # Positional embedding broadcasts over the batch dimension.
        x_enc = self.encoder(x_emb + x_ind)
        x_rt = self.decoder(x_enc)
        x_rt = torch.flatten(x_rt)
        return x_rt
+
+
class RT_pred_model_self_attention_pretext(nn.Module):
    """RT regressor with an auxiliary sequence-reconstruction pretext head.

    forward() returns a pair: (scalar RT per sequence, per-position
    residue logits reconstructed from the attended encoding).
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=23, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(RT_pred_model_self_attention_pretext, self).__init__()
        # SelectItem(0): keep the per-position GRU output sequence.
        self.encoder = nn.Sequential(
            nn.GRU(input_size=embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # Pretext branch: GRU over the attended encoding, one hidden state
        # per position, later projected to residue logits by self.regressor.
        self.decoder_rec = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.regressor_layer_size, num_layers=1,
                   bidirectional=False,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # NOTE(review): encoder output is batch-first while
        # nn.MultiheadAttention defaults to seq-first — confirm intended.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, 1)

        # RT branch: single-head self-attention pooling + MLP regressor.
        self.decoder = nn.Sequential(
            SelfAttention(self.recurrent_layers_sizes[1] * 2),
            nn.Linear(self.recurrent_layers_sizes[1] * 2, regressor_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.latent_dropout_rate),
            nn.Linear(regressor_layer_size, 1)
        )

        self.regressor = nn.Linear(self.regressor_layer_size, self.nb_aa)

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        self.decoder_rec.float()
        self.regressor.float()
        self.attention.float()
        self.encoder.float()
        self.decoder.float()
        self.emb.float()

    def forward(self, seq):
        # seq: LongTensor of residue indices in [0, nb_aa).
        x = torch.nn.functional.one_hot(seq, self.nb_aa)
        x_emb = self.emb(x.float())
        x_enc = self.encoder(x_emb)
        x_rt = self.decoder(x_enc)
        # Pretext path: attend, decode, project back to residue logits.
        enc_att, _ = self.attention(x_enc, x_enc, x_enc)
        dec_att = self.decoder_rec(enc_att)
        seq_rec = self.regressor(dec_att)
        x_rt = torch.flatten(x_rt)
        return x_rt, seq_rec
+
+
class Intensity_pred_model_multi_head(nn.Module):
    """Fragment-intensity predictor fusing a sequence encoding with
    collision energy / precursor charge metadata.

    Output: one intensity per fragment slot, trimmed to 174 values
    (29 cleavage sites x 6 ion/charge combinations).
    """

    def __init__(self, drop_rate=0.5, embedding_output_dim=16, nb_aa=22, latent_dropout_rate=0.1,
                 regressor_layer_size=512,
                 recurrent_layers_sizes=(256, 512), ):
        self.drop_rate = drop_rate
        self.regressor_layer_size = regressor_layer_size
        self.latent_dropout_rate = latent_dropout_rate
        self.recurrent_layers_sizes = recurrent_layers_sizes
        self.nb_aa = nb_aa
        self.embedding_output_dim = embedding_output_dim
        super(Intensity_pred_model_multi_head, self).__init__()
        # SelectItem(0): keep the per-position GRU output sequence.
        self.seq_encoder = nn.Sequential(
            nn.GRU(input_size=self.embedding_output_dim, hidden_size=self.recurrent_layers_sizes[0], num_layers=1,
                   bidirectional=True, batch_first=True),
            SelectItem(0),
            nn.ReLU(),
            nn.Dropout(p=drop_rate),
            nn.GRU(input_size=self.recurrent_layers_sizes[0] * 2, hidden_size=self.recurrent_layers_sizes[1],
                   num_layers=1, bidirectional=True,
                   batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        # 7 = total width of concat([charge, energy]) — presumably a
        # one-hot charge plus energy feature(s); verify against the caller.
        self.meta_enc = nn.Sequential(nn.Linear(7, self.recurrent_layers_sizes[1] * 2))

        self.emb = nn.Linear(self.nb_aa, self.embedding_output_dim)

        # NOTE(review): seq_encoder output is batch-first while
        # nn.MultiheadAttention defaults to seq-first — confirm intended.
        self.attention = nn.MultiheadAttention(self.recurrent_layers_sizes[1] * 2, 1)

        self.decoder = nn.Sequential(
            nn.GRU(input_size=self.recurrent_layers_sizes[1] * 2, hidden_size=self.regressor_layer_size, num_layers=1,
                   bidirectional=False, batch_first=True),
            SelectItem(0),
            nn.Dropout(p=drop_rate),
        )

        self.regressor = nn.Linear(self.regressor_layer_size, 1)

        # intensity range from 0 to 1 (-1 mean impossible)
        self.meta_enc.float()
        self.seq_encoder.float()
        self.decoder.float()
        self.emb.float()
        self.attention.float()
        self.regressor.float()

    def forward(self, seq, energy, charge):
        # seq: residue indices; energy/charge: per-sample metadata tensors.
        x = torch.nn.functional.one_hot(seq.long(), self.nb_aa)
        x_emb = self.emb(x.float())
        out_1 = self.seq_encoder(x_emb)
        weight_out, _ = self.attention(out_1, out_1, out_1)
        # metadata encoder
        out_2 = self.meta_enc(torch.concat([charge, energy], 1))
        # Tile the metadata vector across the 30 sequence positions
        # (assumes fixed padded length 30 — TODO confirm).
        out_2 = out_2.repeat(30, 1, 1)
        out_2 = out_2.transpose(0, 1)
        fusion_encoding = torch.mul(out_2, weight_out)
        # Repeat the fused encoding 6x (-> 180 positions), one copy per
        # ion/charge combination; trimmed to the 174 real slots below.
        fusion_encoding_rep = fusion_encoding.repeat(1, 6, 1)
        out = self.decoder(fusion_encoding_rep)
        intensity = self.regressor(out)
        intensity = torch.flatten(intensity, start_dim=1)
        intensity = intensity[:, :174]

        return intensity
diff --git a/model_custom.py b/model_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..96cfc64035ba8c8c40d79b94b81246c05cc54005
--- /dev/null
+++ b/model_custom.py
@@ -0,0 +1,230 @@
+import math
+import torch.nn as nn
+import torch
+from tape import TAPETokenizer
+from tape.models.modeling_bert import ProteinBertModel
+
class PermuteLayer(nn.Module):
    """nn.Module wrapper around ``torch.permute`` so a fixed axis
    reordering can be placed inside an ``nn.Sequential``."""

    def __init__(self, dims):
        super().__init__()
        # Axis order applied at every forward pass.
        self.dims = dims

    def forward(self, x):
        return torch.permute(x, self.dims)
+
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Takes ``x`` shaped ``[batch, seq_len, d_model]`` and returns the
    dropout-regularised sum ``x + pe`` shaped ``[seq_len, batch, d_model]``
    — note the permute: downstream transformer layers are seq-first.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 26):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Standard sin/cos table: even feature indices get sin, odd get cos,
        # with geometrically spaced frequencies.
        pos = torch.arange(max_len).unsqueeze(1)
        freq = torch.exp(-math.log(10000.0) / d_model * torch.arange(0, d_model, 2))
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(pos * freq)
        table[:, 0, 1::2] = torch.cos(pos * freq)
        # Buffer (not a Parameter): moves with .to(device), no gradients.
        self.register_buffer('pe', table)

    def forward(self, x):
        # [batch, seq, dim] -> [seq, batch, dim], then add the table for
        # the first seq_len positions (broadcast over batch).
        seq_first = torch.permute(x, (1, 0, 2))
        return self.dropout(seq_first + self.pe[:seq_first.size(0)])
+
+
class Model_Common_Transformer(nn.Module):
    """Shared transformer encoder with two task heads: retention-time
    regression and fragment-intensity prediction (conditioned on charge).

    Tensor layout is seq-first ([seq_len, batch, dim]) after
    PositionalEncoding's internal permute, matching the default
    (batch_first=False) nn.TransformerEncoderLayer.
    """

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        self.embedding_dim = embedding_dim
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate
        super(Model_Common_Transformer, self).__init__()

        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                                        dim_feedforward=self.encoder_ff,
                                                                        dropout=self.drop_rate, activation=acti,
                                                                        norm_first=norm),
                                             num_layers=self.encoder_num_layer)

        # One-hot precursor charge -> embedding used to modulate the
        # shared encoding for the intensity head.
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        # RT head: extra transformer layers, back to batch-first, flatten
        # all positions, MLP -> 1 scalar.
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_rt_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm),
                                  num_layers=self.decoder_rt_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        # Intensity head: (seq_length - 1) cleavage sites x charge_frag_max
        # fragment charges x 2 ion series (b/y) output values.
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_int_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm),
                                  num_layers=self.decoder_int_num_layer),
            PermuteLayer((1, 0, 2)),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        self.emb = nn.Linear(self.nb_aa, self.embedding_dim)

        self.pos_embedding = PositionalEncoding(max_len=self.seq_length, dropout=self.drop_rate,
                                                d_model=self.embedding_dim)

    def forward(self, seq, charge):
        """Return (rt_predictions, intensity_predictions) for a batch."""
        # charge is 1-based; shift to 0-based for one_hot.
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        # pos_embedding permutes to [seq, batch, dim] for the encoder.
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)

        enc = self.encoder(emb)

        out_rt = self.decoder_RT(enc)
        # (seq, batch, dim) * (batch, dim): broadcasts over the seq axis.
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        """RT-only path (no charge metadata needed)."""
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        enc = self.encoder(emb)
        out_rt = self.decoder_RT(enc)

        return out_rt.flatten()

    def forward_int(self, seq, charge):
        """Intensity-only path, conditioned on precursor charge."""
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float()
        emb = self.pos_embedding(self.emb(seq_emb))
        meta_enc = self.meta_enc(meta_ohe)
        enc = self.encoder(emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_int
+
class Model_Common_Transformer_TAPE(nn.Module):
    """Variant of Model_Common_Transformer that embeds sequences with a
    pretrained TAPE ProteinBert model instead of one-hot + Linear.

    All transformer layers here are batch_first=True, so tensors are
    [batch, seq_len, 768] throughout. The ``embedding_dim`` argument is
    accepted for signature compatibility but overridden to 768, the TAPE
    hidden size.
    """

    def __init__(self, drop_rate=0.1, embedding_dim=128, nb_aa=21,
                 regressor_layer_size_rt=512, regressor_layer_size_int=512, decoder_rt_ff=512, decoder_int_ff=512,
                 n_head=1, seq_length=25,
                 charge_max=4, charge_frag_max=3, encoder_ff=512, encoder_num_layer=1, decoder_rt_num_layer=1,
                 decoder_int_num_layer=1, acti='relu', norm=False):
        self.charge_max = charge_max
        self.seq_length = seq_length
        self.nb_aa = nb_aa
        self.charge_frag_max = charge_frag_max
        self.n_head = n_head
        # Forced to the ProteinBert hidden size; embedding_dim is ignored.
        self.embedding_dim = 768
        self.encoder_ff = encoder_ff
        self.encoder_num_layer = encoder_num_layer
        self.decoder_rt_ff = decoder_rt_ff
        self.decoder_rt_num_layer = decoder_rt_num_layer
        self.regressor_layer_size_rt = regressor_layer_size_rt
        self.decoder_int_ff = decoder_int_ff
        self.decoder_int_num_layer = decoder_int_num_layer
        self.regressor_layer_size_int = regressor_layer_size_int
        self.drop_rate = drop_rate
        super(Model_Common_Transformer_TAPE, self).__init__()

        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                                        dim_feedforward=self.encoder_ff,
                                                                        dropout=self.drop_rate, activation=acti,
                                                                        norm_first=norm, batch_first=True),
                                             num_layers=self.encoder_num_layer)

        # One-hot precursor charge -> embedding that modulates the shared
        # encoding for the intensity head.
        self.meta_enc = nn.Linear(self.charge_max, self.embedding_dim)

        # RT head: transformer layers, flatten all positions, MLP -> scalar.
        self.decoder_RT = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_rt_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm,
                                                             batch_first=True),
                                  num_layers=self.decoder_rt_num_layer),

            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_rt),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_rt, 1)
        )

        # Intensity head: (seq_length - 1) sites x charge_frag_max fragment
        # charges x 2 ion series (b/y) outputs.
        self.decoder_int = nn.Sequential(
            nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.n_head,
                                                             dim_feedforward=self.decoder_int_ff,
                                                             dropout=self.drop_rate, activation=acti, norm_first=norm,
                                                             batch_first=True),
                                  num_layers=self.decoder_int_num_layer),
            nn.Flatten(),
            nn.Linear(self.embedding_dim * self.seq_length, self.regressor_layer_size_int),
            nn.ReLU(),
            nn.Dropout(p=self.drop_rate),
            nn.Linear(self.regressor_layer_size_int, (self.seq_length - 1) * self.charge_frag_max * 2)
        )

        # Pretrained ProteinBert weights shipped with the repository.
        self.model_TAPE = ProteinBertModel.from_pretrained("./ProteinBert")

    def _expand_meta(self, meta_enc):
        # Broadcast the (batch, 768) charge encoding across the sequence
        # positions -> (batch, seq_length, 768) for element-wise fusion.
        meta_enc = meta_enc.unsqueeze(-1).expand(-1, -1, self.seq_length)
        return torch.permute(meta_enc, (0, 2, 1))

    def forward(self, seq, charge):
        """Return (rt_predictions, intensity_predictions) for a batch."""
        # charge is 1-based; shift to 0-based for one_hot.
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)

        seq_emb = output[0]  # per-position hidden states (batch, seq, 768)
        meta_enc = self._expand_meta(self.meta_enc(meta_ohe))
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_rt.flatten(), out_int

    def forward_rt(self, seq):
        """RT-only path (no charge metadata needed)."""
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        enc = self.encoder(seq_emb)
        out_rt = self.decoder_RT(enc)

        return out_rt.flatten()

    def forward_int(self, seq, charge):
        """Intensity-only path, conditioned on precursor charge."""
        meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float()
        output = self.model_TAPE(seq)
        seq_emb = output[0]
        # Fix: mirror forward() — without the expand/permute the (batch,
        # 768) metadata cannot broadcast against (batch, seq, 768) and the
        # element-wise mul fails (or silently misbroadcasts when batch
        # happens to equal seq_length).
        meta_enc = self._expand_meta(self.meta_enc(meta_ohe))
        enc = self.encoder(seq_emb)
        int_enc = torch.mul(enc, meta_enc)
        out_int = self.decoder_int(int_enc)

        return out_int
diff --git a/msms_processing.py b/msms_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3bbb8dd3d06f0f8da5204ce487387e3822e88e
--- /dev/null
+++ b/msms_processing.py
@@ -0,0 +1,188 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+import json
+
+
def load_data(msms_filet_path='data/msms.txt', score_treshold=70):
    """Load a MaxQuant msms.txt export and keep confident, short PSMs.

    Args:
        msms_filet_path: path to the tab-separated msms.txt file.
        score_treshold: minimum Andromeda score for a row to be kept.

    Returns:
        DataFrame with columns Sequence, Length, Charge, Retention time,
        Score and a dense per-row 'Spectra' intensity vector (see
        filter_intensity), restricted to sequences shorter than 26.
    """
    data = pd.read_csv(msms_filet_path, sep='\t')
    data_compact = data[['Sequence', 'Length', 'Charge', 'Retention time', 'Score', 'Matches', 'Intensities']]
    # Apply both filters in one mask, then .copy(): the previous chained
    # slice-of-a-slice assignment triggered SettingWithCopyWarning and
    # could silently write to a temporary view.
    keep = (data_compact['Score'] > score_treshold) & (data_compact['Length'] < 26)
    data_filtered = data_compact[keep].copy()
    data_filtered['Spectra'] = data_filtered.apply(lambda x: filter_intensity(x.Matches, x.Intensities), axis=1)
    return data_filtered[['Sequence', 'Length', 'Charge', 'Retention time', 'Score', 'Spectra']]
+
+
def convert(l):
    """Parse the whitespace-separated numeric tokens of *l* into floats.

    Tokens may carry stray '[' / ']' characters and a positive exponent
    ('e+'); anything that is not purely numeric after stripping those
    (e.g. negative numbers, words) is silently skipped.
    """
    numbers = []
    for token in l.split():
        stripped = token.replace('[', '').replace(']', '')
        if stripped.replace('.', '').replace('e+', '').isdigit():
            numbers.append(float(stripped))
    return numbers
+
+
def filter_intensity(matches, int_exp):
    """Map matched fragment names to a dense intensity vector.

    Args:
        matches: ';'-separated fragment names (e.g. 'y1;b3;y2(2+)').
        int_exp: ';'-separated intensities aligned with *matches*.

    Returns:
        np.ndarray of shape (144,): one intensity per canonical fragment
        slot (y/b ions 1-24 at charges 1-3), 0.0 where not observed.
    """
    frag_name = ['y1', 'y2', 'y3', 'y4', 'y5', 'y6', 'y7', 'y8', 'y9', 'y10', 'y11', 'y12', 'y13', 'y14', 'y15', 'y16',
                 'y17', 'y18', 'y19', 'y20', 'y21', 'y22', 'y23', 'y24',
                 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 'b10', 'b11', 'b12', 'b13', 'b14', 'b15', 'b16',
                 'b17', 'b18', 'b19', 'b20', 'b21', 'b22', 'b23', 'b24',
                 'y1(2+)', 'y2(2+)', 'y3(2+)', 'y4(2+)', 'y5(2+)', 'y6(2+)', 'y7(2+)', 'y8(2+)', 'y9(2+)', 'y10(2+)',
                 'y11(2+)', 'y12(2+)', 'y13(2+)', 'y14(2+)', 'y15(2+)', 'y16(2+)',
                 'y17(2+)', 'y18(2+)', 'y19(2+)', 'y20(2+)', 'y21(2+)', 'y22(2+)', 'y23(2+)', 'y24(2+)',
                 'b1(2+)', 'b2(2+)', 'b3(2+)', 'b4(2+)', 'b5(2+)', 'b6(2+)', 'b7(2+)', 'b8(2+)', 'b9(2+)', 'b10(2+)',
                 'b11(2+)', 'b12(2+)', 'b13(2+)', 'b14(2+)', 'b15(2+)', 'b16(2+)',
                 'b17(2+)', 'b18(2+)', 'b19(2+)', 'b20(2+)', 'b21(2+)', 'b22(2+)', 'b23(2+)', 'b24(2+)',
                 'y1(3+)', 'y2(3+)', 'y3(3+)', 'y4(3+)', 'y5(3+)', 'y6(3+)', 'y7(3+)', 'y8(3+)', 'y9(3+)',
                 'y10(3+)', 'y11(3+)', 'y12(3+)', 'y13(3+)', 'y14(3+)', 'y15(3+)', 'y16(3+)', 'y17(3+)', 'y18(3+)',
                 'y19(3+)', 'y20(3+)', 'y21(3+)', 'y22(3+)', 'y23(3+)',
                 'y24(3+)',
                 'b1(3+)', 'b2(3+)', 'b3(3+)', 'b4(3+)', 'b5(3+)', 'b6(3+)', 'b7(3+)', 'b8(3+)', 'b9(3+)', 'b10(3+)',
                 'b11(3+)', 'b12(3+)', 'b13(3+)', 'b14(3+)', 'b15(3+)', 'b16(3+)',
                 'b17(3+)', 'b18(3+)', 'b19(3+)', 'b20(3+)', 'b21(3+)', 'b22(3+)', 'b23(3+)', 'b24(3+)'
                 ]
    # Explicit name -> slot lookup. The previous np.isin/np.where pairing
    # implicitly assumed *matches* arrived in frag_name order; when they
    # did not, intensities were written to the wrong fragment slots.
    slot = {name: i for i, name in enumerate(frag_name)}
    intensity = np.zeros(len(frag_name))
    for name, value in zip(matches.split(";"), int_exp.split(";")):
        if name in slot:  # unknown ion types (e.g. neutral losses) are skipped
            intensity[slot[name]] = float(value)
    return intensity
+
+
def merge_dataset_by_name(list_names):
    """Read every CSV file in *list_names* and stack them row-wise into a
    single DataFrame (original indices are kept, as pd.concat does)."""
    return pd.concat([pd.read_csv(name) for name in list_names])
+
+import matplotlib.pyplot as plt
+
def mscatter(x,y, ax=None, m=None, **kw):
    """Scatter plot that accepts one marker spec *per point*.

    Args:
        x, y: point coordinates.
        ax: target axes; defaults to the current pyplot axes.
        m: optional sequence of marker specs, one per point; applied only
           when its length matches ``len(x)``, otherwise ignored.
        **kw: forwarded to ``Axes.scatter``.

    Returns:
        The PathCollection created by ``scatter``.
    """
    import matplotlib.markers as mmarkers
    ax = ax or plt.gca()
    sc = ax.scatter(x,y,**kw)
    if (m is not None) and (len(m)==len(x)):
        paths = []
        for marker in m:
            if isinstance(marker, mmarkers.MarkerStyle):
                marker_obj = marker
            else:
                marker_obj = mmarkers.MarkerStyle(marker)
            # Bake the marker's own transform into its path so set_paths
            # renders each point with its designated marker shape.
            path = marker_obj.get_path().transformed(
                        marker_obj.get_transform())
            paths.append(path)
        sc.set_paths(paths)
    return sc
+
+if __name__ == '__main__':
+    pass
+    # data_2 = load_data('data/Custom_dataset/msmsHBM_UCGTs.txt', i)
+    # data_1 = load_data('data/Custom_dataset/msmsHBM_P450s.txt', i)
+    # data_3 = load_data('data/Custom_dataset/msmsMkBM_P450s.txt', i)
+    # data = pd.concat([data_1, data_2, data_3], ignore_index=True)
+    # err_rt = []
+    # err_spec = []
+    # nb_data = []
+    # nb_gr = []
+    # for i in range(0, 120, 5):
+    #     print(i)
+    #     data_2 = load_data('data/Custom_dataset/msmsHBM_UCGTs.txt', i)
+    #     data_1 = load_data('data/Custom_dataset/msmsHBM_P450s.txt', i)
+    #     data_3 = load_data('data/Custom_dataset/msmsMkBM_P450s.txt', i)
+    #     data = pd.concat([data_1, data_2, data_3], ignore_index=True)
+    #     groups = data.groupby('Sequence')
+    #     avg_err = 0
+    #     avg_cosim = 0
+    #     nb_data.append(len(data))
+    #     nb_gr.append(len(groups))
+    #     for seq, gr in groups:
+    #         mean = gr['Retention time'].mean()
+    #         mean_spec = np.mean(gr['Spectra'], axis=0)
+    #         cos_sim = gr['Spectra'].apply(
+    #             lambda x: (np.dot(x, mean_spec) / (np.linalg.norm(x) * np.linalg.norm(mean_spec)))).mean()
+    #         err = abs(gr['Retention time'] - mean).mean()
+    #         avg_err += err * gr.shape[0]
+    #         avg_cosim += cos_sim * gr.shape[0]
+    #     avg_err = avg_err / data.shape[0]
+    #     avg_cosim = avg_cosim / data.shape[0]
+    #     err_rt.append(avg_err)
+    #     err_spec.append(avg_cosim)
+    #
+    # fig, axs = plt.subplots(2, 2)
+    # axs[0, 0].scatter(range(0, 120, 5),  1 -2 * np.arccos(err_spec)/np.pi)
+    # axs[1, 0].scatter(range(0, 120, 5), err_rt)
+    # axs[0, 1].scatter(range(0, 120, 5), nb_gr)
+    # axs[1, 1].scatter(range(0, 120, 5), nb_data)
+    # axs[0, 0].set_title('spectral angle')
+    # axs[1, 0].set_title('avg rt err')
+    # axs[0, 1].set_title('nb groupes')
+    # axs[1, 1].set_title('nb data')
+    # plt.savefig('fig_data_full.png')
+    # data_1 = load_data('data/Custom_dataset/msms_MsBM_UGTs.txt', 55)
+    # data_2 = load_data('data/Custom_dataset/msmsHBM_UCGTs.txt', 55)
+    # data_3 = load_data('data/Custom_dataset/msmsMkBM_P450s.txt', 55)
+    # data = pd.concat([data_1, data_2, data_3], ignore_index=True)
+
+
+    # data_2 = load_data('data/Custom_dataset/msmsHBM_UCGTs.txt', 55)
+    # data_1 = load_data('data/Custom_dataset/msmsHBM_P450s.txt', 55)
+    # data_3 = load_data('data/Custom_dataset/msmsMkBM_P450s.txt', 55)
+    # data = pd.concat([data_1, data_2, data_3], ignore_index=True)
+    # data.to_csv('database/data_3_first_55.csv')
+
+    # pd.to_pickle(data,"database/data_all_type.pkl")
+    # data = pd.read_pickle("database/data_all_type.pkl")
+    # data['number comp'] = data['Spectra'].apply(lambda x: np.sum(x > 0.1))
+    # sizes_gr = []
+    #
+    # groups = data.groupby('Sequence')
+    # for seq, gr in groups:
+    #     groups_2 = gr.groupby('Charge')
+    #     for ch,gr_2 in groups_2 :
+    #         array = np.stack(gr_2['Spectra'])
+    #         sizes_gr.append(array.shape[0])
+    #         if array.shape[0] > 10:
+    #
+    #             standardized_data = (array - array.mean(axis=0)) / array.std(axis=0)
+    #             standardized_data = np.nan_to_num(standardized_data)
+    #             covariance_matrix = np.cov(standardized_data, ddof=1, rowvar=False, dtype=np.float32)
+    #
+    #             eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)
+    #
+    #             # np.argsort can only provide lowest to highest; use [::-1] to reverse the list
+    #             order_of_importance = np.argsort(eigenvalues)[::-1]
+    #
+    #             # utilize the sort order to sort eigenvalues and eigenvectors
+    #             sorted_eigenvalues = eigenvalues[order_of_importance]
+    #             sorted_eigenvectors = eigenvectors[:, order_of_importance]  # sort the columns
+    #             # use sorted_eigenvalues to ensure the explained variances correspond to the eigenvectors
+    #             explained_variance = sorted_eigenvalues / np.sum(sorted_eigenvalues)
+    #
+    #             k = 2  # select the number of principal components
+    #             reduced_data = np.matmul(standardized_data, sorted_eigenvectors[:, :k])  # transform the original data
+    #
+    #             ### Step 7: Determine the Explained Variance
+    #             total_explained_variance = sum(explained_variance[:k])
+    #
+    #             ### Potential Next Steps: Iterate on the Number of Principal Components
+    #             cm = plt.cm.get_cmap('RdYlBu')
+    #             f, ( ax3, ax2) = plt.subplots(1, 2)
+    #             # sc = mscatter(reduced_data[:, 0], reduced_data[:, 1], ax = ax1,  m = mark[gr['Charge']], s=15, c=gr['Score'], vmin=50, vmax=250, cmap=cm)
+    #             # ax1.text(0.25, 0, total_explained_variance,
+    #             #          horizontalalignment='center',
+    #             #          verticalalignment='center',
+    #             #          transform=ax1.transAxes)
+    #             # f.colorbar(sc, ax=ax1)
+    #
+    #             sc2 = mscatter(reduced_data[:, 0], reduced_data[:, 1], ax = ax2, s=5, c=gr_2['number comp'],
+    #                               vmin=min(gr['number comp']), vmax=max(gr['number comp']), cmap=cm)
+    #             f.colorbar(sc2, ax=ax2)
+    #
+    #             sc3 = mscatter(reduced_data[:, 0], reduced_data[:, 1], ax = ax3, s=5, c=gr_2['Score'], vmin=min(gr_2['Score']), vmax=max(gr_2['Score']), cmap=cm)
+    #             f.colorbar(sc3, ax=ax3)
+    #
+    #             plt.savefig('fig/pca/pca_' + seq + '_' + str(ch) + '.png')
+    #             plt.close()
diff --git a/mzimage_viz.py b/mzimage_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a3d03fad9cda26c027f93d2fdd8123db00f099
--- /dev/null
+++ b/mzimage_viz.py
@@ -0,0 +1,20 @@
import nibabel as nib
import numpy as np


from nilearn import plotting

# Quick visualisation script: loads the binned m/z intensity cube saved by
# mzml_exploration.py ('data/mz_image/Staph140.npy') and renders one slice.
data = np.load('data/mz_image/Staph140.npy')

# Wrap the raw array in a NIfTI image with an identity affine so nilearn can
# treat the (cycle, fragmentation-window, m/z-bin) axes as spatial axes.
new_image = nib.Nifti1Image(data, affine=np.eye(4))



plotting.plot_stat_map(
    new_image,
    bg_img=None,          # no anatomical background: this is not brain data
    display_mode="x",
    cut_coords=1,
    threshold=4,          # hide low log-intensity bins
    title="Test",
)
\ No newline at end of file
diff --git a/mzml_exploration.py b/mzml_exploration.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e6346493f931b95da7306429950dbf276720f0e
--- /dev/null
+++ b/mzml_exploration.py
@@ -0,0 +1,305 @@
+import pyopenms as oms
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.colors as colors
+from PIL import Image
+
def plot_spectra_2d(exp, ms_level=1, marker_size=5):
    """Scatter-plot every spectrum of *exp* at the given MS level as an
    RT-vs-m/z map, coloured by log-scaled peak intensity.

    Draws one scatter call per spectrum, so this is slow on large runs.
    """
    exp.updateRanges()
    # Shared colour bounds so every spectrum uses the same intensity scale
    # (+1 on the minimum keeps the log norm away from zero).
    intensity_lo = exp.getMinIntensity() + 1
    intensity_hi = exp.getMaxIntensity()
    for spectrum in exp:
        if spectrum.getMSLevel() != ms_level:
            continue
        mz, intensity = spectrum.get_peaks()
        order = intensity.argsort()  # most intense peaks drawn last (on top)
        rt = np.full([mz.shape[0]], spectrum.getRT(), float)
        plt.scatter(
            rt,
            mz[order],
            c=intensity[order],
            cmap="afmhot_r",
            s=marker_size,
            norm=colors.LogNorm(intensity_lo, intensity_hi),
        )
    plt.clim(intensity_lo, intensity_hi)
    plt.xlabel("time (s)")
    plt.ylabel("m/z")
    plt.colorbar()
    plt.show()  # slow for larger data sets
+
+
def count_data_points(exp):
    """Return the total number of peaks across all spectra of *exp*."""
    return sum(spectrum.size() for spectrum in exp.getSpectra())
+
+
def reconstruct_spectra(exp, ind):
    """Return the chromatogram intensities whose retention time falls between
    spectrum *ind* and spectrum *ind*+1, together with spectrum *ind*'s peaks.

    NOTE(review): hard-codes chromatogram index 1 -- presumably the trace of
    interest; confirm against the mzML layout.
    """
    a1 = exp.getChromatograms()[1]
    ref = exp.getSpectrum(ind)
    rt1 = ref.getRT()
    rt2 = exp.getSpectrum(ind + 1).getRT()
    # peaks[0] holds retention times, peaks[1] the matching intensities;
    # keep only the points inside the [rt1, rt2] window.
    peaks = a1.get_peaks()
    data = peaks[1][(rt1 <= peaks[0]) & (peaks[0] <= rt2)]
    return data, ref.get_peaks()
+
def build_image(e, bin_mz):
    """Bin a DIA run into a (cycle, row, m/z-bin) intensity cube.

    Column 0 of each cycle's channel holds the MS1 precursor profile (rows =
    MS1 m/z bins); the remaining columns hold binned MS2 intensities (rows =
    isolation-window / 'experiment' index from the native ID).

    Args:
        e: a pyopenms MSExperiment.
        bin_mz: target MS2 bin width in m/z units.

    Returns:
        numpy array of shape (max_cycle, 100, n_bin_ms2 + 1).
    """
    e.updateRanges()
    # Native IDs look like "... cycle=N experiment=M ..."; the last spectrum's
    # cycle number gives the total number of cycles.
    id = e.getSpectra()[-1].getNativeID()
    dico = dict(s.split('=', 1) for s in id.split())
    max_cycle = int(dico['cycle'])
    list_cycle = [[] for _ in range(max_cycle)]

    # m/z acquisition window of the first MS2 spectrum; assumed identical
    # for all MS2 scans of the run -- TODO confirm.
    for s in e:
        if s.getMSLevel() == 2:
            ms2_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms2_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    # Same for the MS1 window.
    for s in e:
        if s.getMSLevel() == 1:
            ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms1_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    total_ms2_mz = ms2_end_mz - ms2_start_mz
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = 100  # fixed MS1 resolution, for now
    size_bin_ms1 = total_ms1_mz / n_bin_ms1

    # Group spectra by cycle number; keep the MS1 scan first in each group.
    for spec in e:  # data structure
        id = spec.getNativeID()
        dico = dict(s.split('=', 1) for s in id.split())
        if spec.getMSLevel() == 2:
            list_cycle[int(dico['cycle'])-1].append(spec)
        if spec.getMSLevel() == 1:
            list_cycle[int(dico['cycle'])-1].insert(0, spec)

    im = np.zeros([max_cycle, 100, n_bin_ms2 + 1])
    for c in range(max_cycle):  # Build one cycle image
        j=0
        chan = np.zeros([n_bin_ms1, n_bin_ms2 + 1])
        if len(list_cycle[c])>0 :
            if list_cycle[c][0].getMSLevel() == 1:
                # MS1 intensities go into column 0, binned by precursor m/z;
                # j=1 skips this spectrum in the MS2 loop below.
                j=1
                ms1 = list_cycle[c][0]
                intensity = ms1.get_peaks()[1]
                mz = ms1.get_peaks()[0]
                for i in range(ms1.size()):
                    chan[int((mz[i]-ms1_start_mz) // size_bin_ms1), 0] += intensity[i]

        for k in range(j, len(list_cycle[c])):
            # MS2 intensities: row = experiment index - 2 (experiment=1 is the
            # MS1 scan), column = m/z bin.
            # NOTE(review): MS2 bin 0 also lands in column 0, which the MS1
            # profile uses above -- verify the channel was meant to have a
            # dedicated +1 column offset for MS2 bins.
            ms2 = list_cycle[c][k]
            intensity = ms2.get_peaks()[1]
            mz = ms2.get_peaks()[0]
            id = ms2.getNativeID()
            dico = dict(s.split('=', 1) for s in id.split())
            for i in range(ms2.size()):
                chan[int(dico['experiment'])-2, int((mz[i]-ms2_start_mz) // size_bin_ms2)] += intensity[i]

        im[c, :, :] = chan

    return im
+
def build_image_generic(e, bin_mz, num_RT):
    """Variant of build_image that groups cycles by encountered MS1 scans
    (instead of native-ID cycle numbers) and is meant to aggregate *num_RT*
    consecutive cycles per channel.

    NOTE(review): several points to verify against the intended design:
      * ``total_by_window`` is computed but never used;
      * only every num_RT-th slice of ``im`` is written, the rest stay zero;
      * ``list_cycle[c][k+n]`` indexes within a single cycle's spectrum list,
        so spectra of the following num_RT-1 cycles are never read and the
        index can run past the end of the list (IndexError on some inputs);
      * ``experiment`` increments once per (k, n) pair and can exceed
        ``experiment_max``.
    """
    e.updateRanges()
    list_cycle = []

    # m/z acquisition windows (+10 m/z safety margin) taken from the first
    # MS2/MS1 spectra; assumed identical across the run -- TODO confirm.
    for s in e:
        if s.getMSLevel() == 2:
            ms2_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms2_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    for s in e:
        if s.getMSLevel() == 1:
            ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms1_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    total_ms2_mz = ms2_end_mz - ms2_start_mz + 10
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    total_ms1_mz = ms1_end_mz - ms1_start_mz + 10
    n_bin_ms1 = 100  # fixed MS1 resolution, for now
    size_bin_ms1 = total_ms1_mz / n_bin_ms1

    # Group spectra into cycles: a new cycle starts at each MS1 scan; MS2
    # scans seen before the first MS1 fall into the except branch.
    cycle = -1
    for spec in e:  # data structure
        if spec.getMSLevel() == 1:
            cycle += 1
            list_cycle.append([])
            list_cycle[cycle].insert(0, spec)
        if spec.getMSLevel() == 2:
            try :
                list_cycle[cycle].append(spec)
            except :
                list_cycle.append([])
                list_cycle[cycle].append(spec)
    max_cycle = len(list_cycle)
    total_by_window = max_cycle//num_RT + 1
    # Use the second-to-last cycle to size the window axis (the last cycle
    # may be truncated).
    experiment_max = len(list_cycle[-2])-1
    im = np.zeros([max_cycle, experiment_max, n_bin_ms2 + 1])
    for c in range(0, max_cycle, num_RT):  # Build one cycle image
        j=0
        experiment = 0
        chan = np.zeros([experiment_max, n_bin_ms2 + 1])
        if len(list_cycle[c])>0 :
            if list_cycle[c][0].getMSLevel() == 1:
                # Skip the leading MS1 scan; its peaks are not binned here.
                j=1
                pass

        for k in range(j, len(list_cycle[c])):
            for n in range(num_RT):
                ms2 = list_cycle[c][k+n]
                intensity = ms2.get_peaks()[1]
                mz = ms2.get_peaks()[0]

                for i in range(ms2.size()):
                    chan[experiment, int((mz[i]-ms2_start_mz) // size_bin_ms2)] += intensity[i]
                experiment +=1
        im[c, :, :] = chan

    return im
+
+
def build_image_frag(e, bin_mz):
    """Bin the MS2 (fragment) intensities of a DIA run into a
    (cycle, isolation-window, m/z-bin) cube; MS1 peaks are not stored.

    Args:
        e: a pyopenms MSExperiment.
        bin_mz: target MS2 bin width in m/z units.

    Returns:
        numpy array of shape (max_cycle, 100, n_bin_ms2).
    """
    e.updateRanges()
    # Native IDs look like "... cycle=N experiment=M ..."; the last spectrum's
    # cycle number gives the total number of cycles.
    id = e.getSpectra()[-1].getNativeID()

    dico = dict(s.split('=', 1) for s in id.split())
    max_cycle = int(dico['cycle'])
    list_cycle = [[] for _ in range(max_cycle)]

    # m/z acquisition windows taken from the first MS2/MS1 spectra; assumed
    # identical across the run -- TODO confirm.
    for s in e:
        if s.getMSLevel() == 2:
            ms2_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms2_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    for s in e:
        if s.getMSLevel() == 1:
            ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms1_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    total_ms2_mz = ms2_end_mz - ms2_start_mz
    n_bin_ms2 = int(total_ms2_mz // bin_mz) + 1
    size_bin_ms2 = total_ms2_mz / n_bin_ms2

    total_ms1_mz = ms1_end_mz - ms1_start_mz
    n_bin_ms1 = 100  # fixed MS1 resolution, for now
    # NOTE(review): size_bin_ms1 is computed but unused in this variant.
    size_bin_ms1 = total_ms1_mz / n_bin_ms1
    # Group spectra by cycle number; keep the MS1 scan first in each group.
    for spec in e:  # data structure
        id = spec.getNativeID()
        dico = dict(s.split('=', 1) for s in id.split())
        if spec.getMSLevel() == 2:
            list_cycle[int(dico['cycle'])-1].append(spec)
        if spec.getMSLevel() == 1:
            list_cycle[int(dico['cycle'])-1].insert(0, spec)

    im = np.zeros([max_cycle, 100, n_bin_ms2])

    for c in range(max_cycle):  # Build one cycle image
        j=0
        chan = np.zeros([n_bin_ms1, n_bin_ms2])
        if len(list_cycle[c])>0 :
            if list_cycle[c][0].getMSLevel() == 1:
                # Skip the leading MS1 scan: only fragments are binned here.
                j = 1
                pass
        for k in range(j, len(list_cycle[c])):
            # Row = experiment index - 2 (experiment=1 is the MS1 scan),
            # column = fragment m/z bin.
            ms2 = list_cycle[c][k]
            intensity = ms2.get_peaks()[1]
            mz = ms2.get_peaks()[0]
            id = ms2.getNativeID()
            dico = dict(s.split('=', 1) for s in id.split())
            for i in range(ms2.size()):
                chan[int(dico['experiment'])-2, int((mz[i]-ms2_start_mz) // size_bin_ms2)] += intensity[i]

        im[c, :, :] = chan

    return im
+
+
def check_windows(e):
    """Collect, per cycle, the precursor isolation-window bounds of every
    MS2 spectrum.

    Returns:
        A list with one entry per cycle, each a flat list
        [low_1, high_1, low_2, high_2, ...] of window bounds in m/z.

    NOTE(review): the ms1_*/ms2_* scan-window variables below are computed
    but never used in this function.
    """
    e.updateRanges()
    # Native IDs look like "... cycle=N ..."; the last spectrum's cycle
    # number gives the total number of cycles.
    id = e.getSpectra()[-1].getNativeID()
    dico = dict(s.split('=', 1) for s in id.split())
    max_cycle = int(dico['cycle'])
    list_cycle = [[] for _ in range(max_cycle)]

    for s in e:
        if s.getMSLevel() == 2:
            ms2_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms2_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    for s in e:
        if s.getMSLevel() == 1:
            ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
            ms1_end_mz = s.getInstrumentSettings().getScanWindows()[0].end
            break

    # Group spectra by cycle number; keep the MS1 scan first in each group.
    for spec in e:  # data structure
        id = spec.getNativeID()
        dico = dict(s.split('=', 1) for s in id.split())
        if spec.getMSLevel() == 2:
            list_cycle[int(dico['cycle'])-1].append(spec)
        if spec.getMSLevel() == 1:
            list_cycle[int(dico['cycle'])-1].insert(0, spec)

    res = []

    for c in range(max_cycle):
        res.append([])
        for k in range(0, len(list_cycle[c])):
            spec = list_cycle[c][k]
            if spec.getMSLevel() == 2:
                # Isolation window = precursor m/z +/- the lower/upper offsets.
                b = spec.getPrecursors()
                res[-1].append(b[0].getMZ() - b[0].getIsolationWindowLowerOffset())
                res[-1].append(b[0].getMZ() + b[0].getIsolationWindowUpperOffset())
    return res
+
def check_energy(im):
    """Per (RT, fragmentation-window) ratio of summed fragment intensity
    (channels 1..end) to precursor intensity (channel 0).

    Cells whose precursor channel is zero are left at 0.
    """
    n_rt, n_frag, n_chan = im.shape
    ratios = np.zeros((n_rt, n_frag))
    for rt_idx in range(n_rt):
        for frag_idx in range(n_frag):
            precursor = im[rt_idx, frag_idx, 0]
            if precursor != 0:
                fragment_sum = im[rt_idx, frag_idx, 1:n_chan].sum()
                ratios[rt_idx, frag_idx] = fragment_sum / precursor
    return ratios
+
+if __name__ == "__main__":
+    e = oms.MSExperiment()
+    oms.MzMLFile().load("data/STAPH140.mzML", e)
+    im = build_image_frag(e, 2)
+    im2 = np.maximum(0,np.log(im+1))
+    np.save('data/mz_image/Staph140.npy',im2)
+
+    # norm = np.max(im2)
+    # for i in range(im.shape[0]) :
+    #     mat = im2[i, :, :]
+    #     img = Image.fromarray(mat / norm)
+    #     img.save('fig/mzimage/RT_frag_'+str(i)+'.tif')
+    # res = check_windows(e)
+    #
+    # max_len = np.array([len(array) for array in res]).max()
+    #
+    # # What value do we want to fill it with?
+    # default_value = 0
+    #
+    # b = [np.pad(array, (0, max_len - len(array)), mode='constant', constant_values=default_value) for array in res]
+
+
+
+
diff --git a/prosit_data_merge.py b/prosit_data_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f47f0c09ece7590541dd57df27d66d33b90558
--- /dev/null
+++ b/prosit_data_merge.py
@@ -0,0 +1,117 @@
+import pandas as pd
+import numpy as np
+from common_dataset import Common_Dataset
+from dataloader import load_intensity_df_from_files
+
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import pandas as pd
+
# Vocabulary mapping amino-acid one-letter codes (plus two PTM tokens) to
# integer ids; 0 is reserved for the '_' padding character.
# 'CaC' = carbamidomethylated cysteine, 'OxM' = oxidised methionine.
ALPHABET_UNMOD = {
    "_": 0,
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
    "CaC": 21,
    "OxM": 22
}
+
+
def padding(dataframe, columns, length):
    """Pad every sequence in column *columns* to *length* residues with '_'
    and drop rows whose sequence cannot fit.

    Modifications are encoded inline as 5-character '-XXX-' tokens that stand
    for a single residue, so each one contributes 2 dashes and 4 surplus
    characters; the effective character budget is ``length + 2 * count('-')``.

    Mutates *dataframe* in place; returns None.
    """
    def pad(x):
        # 2 * count('-') compensates the 4 surplus characters of each
        # '-XXX-' modification token (2 dashes per token).
        return x + (length - len(x) + 2 * x.count('-')) * '_'

    # Bug fix: DataFrame.drop returns a new frame by default; the original
    # discarded its result, so over-long rows were never actually removed.
    too_long = [
        idx for idx in dataframe.index
        if len(dataframe.at[idx, columns])
        > length + 2 * dataframe.at[idx, columns].count('-')
    ]
    if too_long:
        dataframe.drop(too_long, inplace=True)
        dataframe.reset_index(drop=True, inplace=True)

    dataframe[columns] = dataframe[columns].map(pad)
+
def zero_to_minus(arr):
    """Replace (near-)zero entries of *arr* with -1 in place and return it."""
    near_zero = arr <= 0.00001
    arr[near_zero] = -1.
    return arr
+
def alphabetical_to_numerical(seq):
    """Convert a (possibly modified) peptide string to a numeric vector.

    Plain residues map through ALPHABET_UNMOD; modifications are encoded
    inline as '-CaC-' (carbamidomethyl-Cys -> 21) or '-OxM-' (oxidised
    Met -> 22), each token spanning 5 characters but counting as one residue
    (hence the ``2 * seq.count('-')`` length correction).

    Raises:
        ValueError: if an unknown modification token is encountered.
    """
    num = []
    dec = 0  # surplus characters consumed so far by modification tokens
    for i in range(len(seq) - 2 * seq.count('-')):
        if seq[i + dec] != '-':
            num.append(ALPHABET_UNMOD[seq[i + dec]])
        else:
            mod = seq[i + dec + 1:i + dec + 4]
            if mod == 'CaC':
                num.append(21)
            elif mod == 'OxM':
                num.append(22)
            else:
                # Bug fix: raising a plain string is a TypeError in Python 3;
                # raise a proper exception type instead.
                raise ValueError('Modification not supported: ' + mod)
            dec += 4  # skip the rest of the '-XXX-' token
    return np.array(num)
+
+
# Merge script: attach iRT values from the RT dataset to the Prosit intensity
# dataset by matching (numerically encoded) peptide sequences, then persist
# the aligned iRT array.
sources = ('data/intensity/sequence_train.npy',
                 'data/intensity/intensity_train.npy',
                 'data/intensity/collision_energy_train.npy',
                 'data/intensity/precursor_charge_train.npy')


data_rt = pd.read_csv('database/data_unique_ptms.csv')
data_rt['Sequence']=data_rt['mod_sequence']

# Pad to 30 residues and encode sequences numerically (drops over-long rows).
padding(data_rt, 'Sequence', 30)
data_rt['Sequence'] = data_rt['Sequence'].map(alphabetical_to_numerical)

data_rt =data_rt.drop(columns='mod_sequence')

data_int = load_intensity_df_from_files(sources[0], sources[1], sources[2], sources[3])

# Sequences become tuples so they are hashable and comparable as dict keys.
seq_rt = data_rt.Sequence
seq_int = data_int.seq
seq_rt = seq_rt.tolist()
seq_int = seq_int.tolist()
seq_rt = [tuple(l) for l in seq_rt]
seq_int = [tuple(l) for l in seq_int]

# Indices (in data_rt) of the sequences that also occur in the intensity set.
ind_dict_rt = dict((k, i) for i, k in enumerate(seq_rt))
inter = set(ind_dict_rt).intersection(seq_int)
ind_dict_rt = [ind_dict_rt[x] for x in inter]


# NOTE(review): attribute-style assignment only updates a DataFrame column if
# 'irt' already exists; otherwise it sets a plain attribute -- verify intent.
data_int.irt = np.zeros(data_int.energy.shape)

# NOTE(review): the inner list comprehension rescans all of seq_int for every
# matched sequence (quadratic overall); a dict mapping sequence -> indices
# built once would make this linear.
i=0
for ind in ind_dict_rt :
    print(i,'/',len(ind_dict_rt))
    i+=1
    ind_int = [k for k, x in enumerate(seq_int) if x == seq_rt[ind]]
    data_int.irt[ind_int] = data_rt.irt[ind]

np.save('data/intensity/collision_irt_train.npy',data_int.irt)

# indices_common = dict((k, i) for i, k in enumerate(seq_int))
# indices_common = [indices_common[x] for x in inter]
#

# data_int.irt[indices_common] = data_rt.irt[ind_dict_rt]
#
+
+
diff --git a/prosit_ori_callback.py b/prosit_ori_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e847ee0ac7ccabd4e0ebc67fc56ea868de8147
--- /dev/null
+++ b/prosit_ori_callback.py
@@ -0,0 +1,142 @@
+import os
+
+import numpy as np
+import tensorflow
+from dlomix.data import RetentionTimeDataset
+from dlomix.eval import TimeDeltaMetric
+from dlomix.models import PrositRetentionTimePredictor
+from matplotlib import pyplot as plt
+from sklearn.metrics import r2_score
+from keras import backend as K
+
+
def save_reg(pred, true, name):
    """Scatter true vs. predicted retention times with the R² score
    annotated, and save the figure to *name*."""
    r2 = round(r2_score(true, pred), 4)
    figure = plt.figure()
    axes = figure.add_subplot(1, 1, 1)
    axes.plot(true, pred, 'y,')
    axes.text(120, 20, 'R² = ' + str(r2), fontsize=12)
    axes.set_xlabel("True")
    axes.set_ylabel("Pred")
    axes.set_xlim([-50, 200])
    axes.set_ylim([-50, 200])
    plt.savefig(name)
    plt.clf()
+
+
def save_evol(pred_prev, pred, true, name):
    """Scatter the current epoch's prediction error against the previous
    epoch's error for the same targets, and save the figure to *name*."""
    figure = plt.figure()
    axes = figure.add_subplot(1, 1, 1)
    current_err = pred - true
    previous_err = pred_prev - true
    axes.plot(current_err, previous_err, 'y,')
    axes.set_xlabel("Current error")
    axes.set_ylabel("Previous error")
    axes.set_xlim([-50, 50])
    axes.set_ylim([-50, 50])
    plt.savefig(name)
    plt.clf()
+
+
class GradientCallback(tensorflow.keras.callbacks.Callback):
    """Keras callback that, at the end of every epoch, computes gradients on
    one training batch, logs them as TensorBoard histograms (to the default
    summary writer), and saves regression / error-evolution plots on the
    holdout set.

    NOTE(review): relies on module-level globals ``rtdata``, ``test_rtdata``
    and ``test_targets`` defined in the __main__ block below.
    """
    console = True

    def on_epoch_end(self, epoch, logs=None, evol=True, reg=True):
        with tensorflow.GradientTape() as tape:
            # Grab only the first training batch.
            for f, y in rtdata.train_data:
                features, y_true = f,y
                break
            y_pred = self.model(features)  # forward-propagation
            loss = self.model.compiled_loss(y_true=y_true, y_pred=y_pred)  # calculate loss
            gradients = tape.gradient(loss, self.model.trainable_weights)
        for weights, grads in zip(self.model.trainable_weights, gradients):
            tensorflow.summary.histogram(
                weights.name.replace(':', '_') + '_grads', data=grads, step=epoch, buckets=100)
        preds = self.model.predict(test_rtdata.test_data)
        if reg :
            save_reg(preds.flatten(), test_targets, 'fig/unstability/reg_epoch_'+str(epoch))
        if evol :
            if epoch >0 :
                # Compare this epoch's errors with the previous epoch's,
                # cached on disk between epochs.
                pred_prev = np.load('temp/mem_pred.npy')
                save_evol(pred_prev, preds.flatten(), test_targets, 'fig/evol/reg_epoch_'+str(epoch))
        np.save('temp/mem_pred.npy', preds.flatten())
+
+
def lr_warmup_cosine_decay(global_step,
                           warmup_steps,
                           hold=0,
                           total_steps=0,
                           start_lr=0.0,
                           target_lr=1e-3):
    """Piecewise learning-rate schedule: linear warmup from *start_lr* to
    *target_lr* over *warmup_steps*, an optional hold at *target_lr* for
    *hold* steps, then cosine decay to 0 at *total_steps*.

    Returns the learning rate for *global_step* (a numpy scalar).

    NOTE(review): assumes total_steps > warmup_steps + hold; equality would
    divide by zero in the cosine term.
    """
    # Cosine decay from target_lr down to 0 over the post-warmup/hold span.
    learning_rate = 0.5 * target_lr * (
            1 + np.cos(np.pi * (global_step - warmup_steps - hold) / float(total_steps - warmup_steps - hold)))

    # Bug fix: the original ignored start_lr entirely; warm up linearly from
    # start_lr to target_lr (identical output when start_lr == 0, which is
    # how every caller in this file invokes it).
    warmup_lr = start_lr + (target_lr - start_lr) * (global_step / warmup_steps)

    # While holding (warmup_steps <= step <= warmup_steps + hold), stay at
    # target_lr; afterwards use the cosine-decayed rate.
    if hold > 0:
        learning_rate = np.where(global_step > warmup_steps + hold,
                                 learning_rate, target_lr)

    learning_rate = np.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate
+
+
class WarmupCosineDecay(tensorflow.keras.callbacks.Callback):
    """Keras callback that applies lr_warmup_cosine_decay before every batch
    and records the learning rate actually used after every batch.

    Args:
        total_steps: total number of optimisation steps for the cosine span.
        warmup_steps: number of linear warmup steps.
        start_lr: learning rate at step 0 of the warmup.
        target_lr: peak learning rate reached after warmup.
        hold: number of steps to hold at target_lr before decaying.
    """

    def __init__(self, total_steps=0, warmup_steps=0, start_lr=0.0, target_lr=1e-3, hold=0):
        super(WarmupCosineDecay, self).__init__()
        self.start_lr = start_lr
        self.hold = hold
        self.total_steps = total_steps
        self.global_step = 0
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.lrs = []  # learning-rate history, one entry per batch

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        # Bug fix: read the lr from the model this callback is attached to
        # (Keras sets self.model), not from a module-level `model` global.
        lr = self.model.optimizer.lr.numpy()
        self.lrs.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = lr_warmup_cosine_decay(global_step=self.global_step,
                                    total_steps=self.total_steps,
                                    warmup_steps=self.warmup_steps,
                                    start_lr=self.start_lr,
                                    target_lr=self.target_lr,
                                    hold=self.hold)
        K.set_value(self.model.optimizer.lr, lr)
+
+
if __name__ == '__main__':
    # Create output directories, ignoring "already exists".
    # NOTE(review): the bare except also hides permission errors --
    # os.makedirs(..., exist_ok=True) would be safer.
    try:
        os.mkdir("./metrics_lr")
    except:
        pass

    try:
        os.mkdir("./logs_lr")
    except:
        pass

    BATCH_SIZE = 1024
    # Training/validation split from the train CSV; separate holdout set.
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv', sequence_col='sequence',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=BATCH_SIZE, test=True)
    test_targets = test_rtdata.get_split_targets(split="test")
    np.save('results/pred_prosit_ori/target.npy', test_targets)
    model = PrositRetentionTimePredictor(seq_length=30)
    model.compile(optimizer='adam',
                  loss='mse',
                  metrics=['mean_absolute_error', TimeDeltaMetric()])
    # Default summary writer used by GradientCallback's histogram logging.
    file_writer = tensorflow.summary.create_file_writer("./metrics_prosit")
    file_writer.set_as_default()

    # write_grads has been removed
    gradient_cb = GradientCallback()
    # tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir="./logs_prosit")
    # lr_callback = WarmupCosineDecay(total_steps=100, warmup_steps=10, start_lr=0.0, target_lr=1e-3, hold=5)
    model.fit(rtdata.train_data, epochs=100, batch_size=BATCH_SIZE, validation_data=rtdata.val_data,
              callbacks=[gradient_cb])
diff --git a/prosit_ori_tf.py b/prosit_ori_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b78fb9cb58a781277a4fa3443e52699e52021b63
--- /dev/null
+++ b/prosit_ori_tf.py
@@ -0,0 +1,174 @@
+import os
+import datetime
+
+from sklearn.metrics import r2_score
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import wandb
+from dlomix.data import RetentionTimeDataset
+from dlomix.eval import TimeDeltaMetric
+from dlomix.models import PrositRetentionTimePredictor, RetentionTimePredictor
+from dlomix.reports import RetentionTimeReport
+import tensorflow
+
+
def save_reg(pred, true, name):
    """Plot predicted vs. true retention times with a linear fit and the R²
    score annotated, and save the figure to *name*."""
    coef = np.polyfit(pred, true, 1)
    poly1d_fn = np.poly1d(coef)
    # Bug fix: sklearn's r2_score signature is (y_true, y_pred); the original
    # passed them swapped, which changes the score (R² is not symmetric).
    # This also matches save_reg in prosit_ori_callback.py.
    r2 = round(r2_score(true, pred), 4)
    plt.plot(pred, true, 'y,', pred, poly1d_fn(pred), '--k')
    plt.text(120, 20, 'R² = ' + str(r2), fontsize=12)
    plt.savefig(name)
    plt.clf()
+
+
def track_train(model, epoch, test_rtdata, rtdata):
    """Manual training loop: Adam + MSE, logging per-step gradients to wandb
    (offline mode) and saving a holdout regression plot after every epoch.

    NOTE(review): ``BATCH_SIZE``, ``train_target`` and ``metric`` are set but
    never used; ``loss(predictions, y_batch)`` passes (y_pred, y_true) --
    harmless for MSE, which is symmetric, but reversed w.r.t. the Keras
    (y_true, y_pred) convention.
    """
    BATCH_SIZE = 256
    test_targets = test_rtdata.get_split_targets(split="test")
    train_target = rtdata.get_split_targets(split="train")
    loss = tensorflow.keras.losses.MeanSquaredError()
    metric = TimeDeltaMetric()
    optimizer = tensorflow.keras.optimizers.Adam()

    # SECURITY(review): hard-coded API key committed to source -- rotate it
    # and read it from the environment instead.
    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'

    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

    wandb.init(project="Prosit ori full dataset", dir='./wandb_run', name='prosit ori')

    for e in range(epoch):
        for step, (X_batch, y_batch) in enumerate(rtdata.train_data):
            with tensorflow.GradientTape() as tape:
                predictions = model(X_batch, training=True)
                l = loss(predictions, y_batch)
            grads = tape.gradient(l, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            # Logs the raw gradient tensors for every step.
            wandb.log({'grads': grads})
        predictions = model.predict(test_rtdata.test_data)
        save_reg(predictions.flatten(), test_targets, 'fig/unstability/reg_epoch_' + str(e))

    wandb.finish()
+
+
def train_step(model, optimizer, x_train, y_train, step):
    """One optimisation step; also logs per-weight gradient histograms to the
    active default summary writer.

    Relies on module-level globals ``loss_object``, ``train_loss`` and
    ``train_accuracy`` (their definitions are currently commented out in the
    __main__ block below).
    """
    with tensorflow.GradientTape() as tape:
        predictions = model(x_train, training=True)
        # NOTE(review): trainable variables are watched automatically; this
        # watch() call after the forward pass is a no-op.
        tape.watch(model.trainable_variables)
        loss = loss_object(y_train, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # NOTE(review): the loop variable shadows `grads` (the gradient list);
    # it works, but rename one of them if this is ever extended.
    for weights, grads in zip(model.trainable_weights, grads):
        tensorflow.summary.histogram(
            weights.name.replace(':', '_') + '_grads', data=grads, step=step)
    train_loss(loss)
    train_accuracy(y_train, predictions)
+
+
def test_step(model, x_test, y_test):
    """Evaluate one batch and update the global test metrics.

    Relies on module-level globals ``loss_object``, ``test_loss`` and
    ``test_accuracy`` (currently commented out in the __main__ block below).
    """
    predictions = model(x_test)
    loss = loss_object(y_test, predictions)

    test_loss(loss)
    test_accuracy(y_test, predictions)
+
+
def main():
    """Train the Prosit RT predictor with keras fit() under offline wandb.

    NOTE(review): ``test_targets`` (computed twice), ``predictions`` and
    ``report`` are never used; ``test_rtdata`` is rebuilt after training even
    though an identical instance already exists.
    """
    BATCH_SIZE = 256

    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32, test=True)
    test_targets = test_rtdata.get_split_targets(split="test")
    model = PrositRetentionTimePredictor(seq_length=30)

    model.compile(optimizer='adam',
                  loss='mse',
                  metrics=['mean_absolute_error', TimeDeltaMetric()])

    # SECURITY(review): hard-coded API key committed to source -- rotate it
    # and read it from the environment instead.
    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'

    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

    wandb.init(project="Prosit ori full dataset", dir='./wandb_run', name='prosit ori')

    history = model.fit(rtdata.train_data,
                        validation_data=rtdata.val_data,
                        epochs=100)

    wandb.finish()
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32, test=True)

    predictions = model.predict(test_rtdata.test_data)
    test_targets = test_rtdata.get_split_targets(split="test")

    report = RetentionTimeReport(output_path="./output", history=history)
+
+
def main_track():
    """Build the train/holdout datasets and run the manual tracking loop
    (track_train) on a RetentionTimePredictor for 100 epochs."""
    batch_size = 256

    train_dataset = RetentionTimeDataset(data_source='database/data_train.csv',
                                         seq_length=30, batch_size=batch_size, val_ratio=0.2, test=False)
    holdout_dataset = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                           seq_length=30, batch_size=batch_size, test=True)
    # Kept for parity with the original: computes (unused) holdout targets.
    holdout_dataset.get_split_targets(split="test")
    predictor = RetentionTimePredictor(seq_length=30)
    track_train(predictor, 100, holdout_dataset, train_dataset)
+
+
if __name__ == '__main__':
    # Scaffold for a manual train/eval loop with TensorBoard summaries; the
    # actual training is commented out and only the first batch of each epoch
    # is printed.
    # loss_object = tensorflow.keras.losses.MeanSquaredError()
    # optimizer = tensorflow.keras.optimizers.Adam()
    # train_loss = tensorflow.keras.metrics.Mean('train_loss', dtype=tensorflow.float32)
    # train_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('train_accuracy')
    # test_loss = tensorflow.keras.metrics.Mean('test_loss', dtype=tensorflow.float32)
    # test_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('test_accuracy')
    #
    # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    # test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
    # train_summary_writer = tensorflow.summary.create_file_writer(train_log_dir)
    # test_summary_writer = tensorflow.summary.create_file_writer(test_log_dir)

    BATCH_SIZE = 256
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32, test=True)
    test_targets = test_rtdata.get_split_targets(split="test")
    # model = RetentionTimePredictor(seq_length=30)
    #
    EPOCHS = 5

    # Debug loop: print the first training batch of each epoch, then stop.
    for epoch in range(EPOCHS):
        for (x_train, y_train) in rtdata.train_data:
            print(x_train)
            break
        #     train_step(model, optimizer, x_train, y_train, epoch)
        # with train_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', train_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', train_accuracy.result(), step=epoch)
        #
        # for (x_test, y_test) in test_rtdata.test_data:
        #     test_step(model, x_test, y_test)
        # with test_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', test_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
        #
        # template = 'Epoch {}, Loss: {}, Absolute Error: {}, Test Loss: {}, Test Absolute Error: {}'
        # print(template.format(epoch + 1,
        #                       train_loss.result(),
        #                       train_accuracy.result(),
        #                       test_loss.result(),
        #                       test_accuracy.result()))
        #
        # # Reset metrics every epoch
        # train_loss.reset_states()
        # test_loss.reset_states()
        # train_accuracy.reset_states()
        # test_accuracy.reset_states()