diff --git a/common_dataset.py b/common_dataset.py
index 5ab8a5f21583168a1371968847032ef911f2f156..19f35c97b153cdbd3a426cda491d13c1026e6d99 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -119,9 +119,10 @@ def zero_to_minus(arr):
 
 class Common_Dataset(Dataset):
 
-    def __init__(self, dataframe, length, pad=True, convert=True, vocab='unmod'):
+    def __init__(self, dataframe, length, pad=True, convert=True, vocab='unmod', file=False):
         print('Data loader Initialisation')
         self.data = dataframe.reset_index()
+        self.file_mode = file
         if pad :
             print('Padding')
             padding(self.data, 'Sequence', length)
@@ -135,10 +136,16 @@ class Common_Dataset(Dataset):
         seq = self.data['Sequence'][index]
         rt = self.data['Retention time'][index]
         intensity = self.data['Spectra'][index]
-
         charge = self.data['Charge'][index]
+        file = self.data['file'][index]
+
+        if self.file_mode :
+            return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity),  torch.tensor(file)
+        else :
+            return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity)
 
-        return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity)
+    def set_file_mode(self,b):
+        self.file_mode=b
 
     def __len__(self) -> int:
         return self.data.shape[0]
diff --git a/data_viz.py b/data_viz.py
index 17edd1bea66b1fa991c3cc46ac0898cdbced5003..8b71678f6c4db51efcc7ce8cfff93820dc79e32a 100644
--- a/data_viz.py
+++ b/data_viz.py
@@ -253,23 +253,23 @@ def add_length(dataframe):
     dataframe['length']=dataframe['seq'].map(fonc)
 
 
-df = pd.read_csv('output/output_common_data_ISA.csv')
+# df = pd.read_csv('output/output_common_data_ISA.csv')
+# add_length(df)
+# df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
+# histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_ISA_ISA.png')
+# scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_ISA_ISA.png', color=True)
+# histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_ISA_ISA.png')
+#
+# df = pd.read_csv('output/out_prosit_common.csv')
+# add_length(df)
+# df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
+# histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_prosit.png')
+# scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit.png', color=True)
+# histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit.png')
+
+df = pd.read_csv('output/out_common_transfereval.csv')
 add_length(df)
 df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
-histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_ISA_ISA.png')
-scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_ISA_ISA.png', color=True)
-histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_ISA_ISA.png')
-
-df = pd.read_csv('output/out_prosit_common.csv')
-add_length(df)
-df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
-histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_prosit.png')
-scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit.png', color=True)
-histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit.png')
-
-df = pd.read_csv('output/out_common_transfer.csv')
-add_length(df)
-df['abs_error'] =  np.abs(df['rt pred']-df['true rt'])
-histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA.png')
-scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA.png', color=True)
-histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA.png')
\ No newline at end of file
+histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA_eval.png')
+scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA_eval.png', color=True)
+histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA_eval.png')
\ No newline at end of file
diff --git a/main_custom.py b/main_custom.py
index 16f55524b189cbe8c3f1b4bd44160c364289fc74..cd1538d15b126efd78d8a7b127d476565334ea75 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -278,8 +278,9 @@ def save_pred(model, data_val, forward, output_path):
     for param in model.parameters():
         param.requires_grad = False
     if forward == 'both':
-        pred_rt, pred_int, seqs, charges, true_rt, true_int = [], [], [], [], [], []
-        for seq, charge, rt, intensity in data_val:
+        pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
+        data_val.data.set_file_mode(True)
+        for seq, charge, rt, intensity, file in data_val:
             rt, intensity = rt.float(), intensity.float()
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
@@ -290,12 +291,14 @@ def save_pred(model, data_val, forward, output_path):
             charges.extend(charge.data.cpu().tolist())
             true_rt.extend(rt.data.cpu().tolist())
             true_int.extend(intensity.data.cpu().tolist())
+            file_list.extend([file])
         data_frame['rt pred'] = pred_rt
         data_frame['seq'] = seqs
         data_frame['pred int'] = pred_int
         data_frame['true rt'] = true_rt
         data_frame['true int'] = true_int
         data_frame['charge'] = charges
+        data_frame['file'] = file_list
 
 
 
diff --git a/msms_processing.py b/msms_processing.py
index 5705fc43b573026f3e058d664e12b41bbf6e3213..b202abb2aba113030d18362ddf0b73e11ed45881 100644
--- a/msms_processing.py
+++ b/msms_processing.py
@@ -83,11 +83,17 @@ def mscatter(x,y, ax=None, m=None, **kw):
 # 17/01 23/01 24/01
 if __name__ == '__main__':
     data_1 = pd.read_pickle('database/data_DIA_16_01_aligned.pkl')
+    data_1['file']= 1
     data_2 = pd.read_pickle('database/data_DIA_17_01_aligned.pkl')
+    data_2['file'] = 2
     data_3 = pd.read_pickle('database/data_DIA_20_01_aligned.pkl')
+    data_3['file'] = 3
     data_4 = pd.read_pickle('database/data_DIA_23_01_aligned.pkl')
+    data_4['file'] = 4
     data_5 = pd.read_pickle('database/data_DIA_24_01_aligned.pkl')
+    data_5['file'] = 5
     data_6 = pd.read_pickle('database/data_DIA_30_01_aligned.pkl')
+    data_6['file'] = 6
     data = pd.concat([data_1, data_2, data_3, data_4, data_5, data_6], ignore_index=True)
 
     num_total = len(data)