diff --git a/main_custom.py b/main_custom.py
index 94bf357516ba9761d9aed6123c0def129c588cea..672f886b048081b12903a77ffbf7e3dd7e662454 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -1,4 +1,6 @@
 import os
+
+import pandas as pd
 import torch
 import torch.optim as optim
 import wandb as wdb
@@ -246,8 +248,68 @@ def get_n_params(model):
         pp += nn
     return pp
 
+def save_pred(model, data_val, forward, output_path):
+    data_frame = pd.DataFrame()
+    for param in model.parameters():
+        param.requires_grad = False
+    if forward == 'both':
+        pred_rt, pred_int, seqs, charges, true_rt, true_int = [], [], [], [], [], []
+        for seq, charge, rt, intensity in data_val:
+            rt, intensity = rt.float(), intensity.float()
+            if torch.cuda.is_available():
+                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
+            pr_rt, pr_intensity = model.forward(seq, charge)
+            pred_rt.extend(pr_rt.data.cpu().tolist())
+            pred_int.extend(pr_intensity.data.cpu().tolist())
+            seqs.extend(seq.data.cpu().tolist())
+            charges.extend(charge.data.cpu().tolist())
+            true_rt.extend(rt.data.cpu().tolist())
+            true_int.extend(intensity.data.cpu().tolist())
+            data_frame['rt pred'] = pred_rt
+            data_frame['seq'] = seqs
+            data_frame['pred int'] = pred_int
+            data_frame['true rt'] = true_rt
+            data_frame['true int'] = true_int
+            data_frame['charge'] = charges
+
+
+
+    if forward == 'rt':  #adapted to prosit dataset format
+        pred_rt, seqs, true_rt = [], [], []
+        for seq, rt in data_val:
+            rt = rt.float()
+            if torch.cuda.is_available():
+                seq, rt = seq.cuda(), rt.cuda()
+            pr_rt = model.forward_rt(seq)
+            pred_rt.extend(pr_rt.data.cpu().tolist())
+            seqs.extend(seq.data.cpu().tolist())
+            true_rt.extend(rt.data.cpu().tolist())
+            data_frame['rt pred'] = pred_rt
+            data_frame['seq'] = seqs
+            data_frame['true rt'] = true_rt
+
+
+    if forward == 'int':  #adapted to prosit dataset format
+        pred_int, seqs, charges, true_int = [], [], [], []
+        for seq, charge, _, intensity in data_val:
+            intensity = intensity.float()
+            if torch.cuda.is_available():
+                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
+            pred_int = model.forward_int(seq, charge)
+            seqs.extend(seq.data.cpu().tolist())
+            charges.extend(charge.data.cpu().tolist())
+            true_int.extend(intensity.data.cpu().tolist())
+            data_frame['seq'] = seqs
+            data_frame['pred int'] = pred_int
+            data_frame['true int'] = true_int
+            data_frame['charge'] = charges
+
+    data_frame.to_csv(output_path)
+
 
 if __name__ == "__main__":
     args = load_args()
     main(args)
 
+
+