From 2a70c389317100b528069ce7eef77e5ef7e589c2 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 30 Sep 2024 10:30:29 +0200
Subject: [PATCH] fix save preds with file

---
 main_custom.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index 7985488..65b27ff 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -285,26 +285,26 @@ def save_pred(model, data_val, forward, output_path, file = False):
                 rt, intensity = rt.float(), intensity.float()
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
-            pr_rt, pr_intensity = model.forward(seq, charge)
-            pred_rt.extend(pr_rt.data.cpu().tolist())
-            pred_int.extend(pr_intensity.data.cpu().tolist())
-            seqs.extend(seq.data.cpu().tolist())
-            charges.extend(charge.data.cpu().tolist())
-            true_rt.extend(rt.data.cpu().tolist())
-            true_int.extend(intensity.data.cpu().tolist())
-            file_list.extend([files])
+                pr_rt, pr_intensity = model.forward(seq, charge)
+                pred_rt.extend(pr_rt.data.cpu().tolist())
+                pred_int.extend(pr_intensity.data.cpu().tolist())
+                seqs.extend(seq.data.cpu().tolist())
+                charges.extend(charge.data.cpu().tolist())
+                true_rt.extend(rt.data.cpu().tolist())
+                true_int.extend(intensity.data.cpu().tolist())
+                file_list.extend([files])
         else :
             for seq, charge, rt, intensity in data_val:
                 rt, intensity = rt.float(), intensity.float()
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
-            pr_rt, pr_intensity = model.forward(seq, charge)
-            pred_rt.extend(pr_rt.data.cpu().tolist())
-            pred_int.extend(pr_intensity.data.cpu().tolist())
-            seqs.extend(seq.data.cpu().tolist())
-            charges.extend(charge.data.cpu().tolist())
-            true_rt.extend(rt.data.cpu().tolist())
-            true_int.extend(intensity.data.cpu().tolist())
+                pr_rt, pr_intensity = model.forward(seq, charge)
+                pred_rt.extend(pr_rt.data.cpu().tolist())
+                pred_int.extend(pr_intensity.data.cpu().tolist())
+                seqs.extend(seq.data.cpu().tolist())
+                charges.extend(charge.data.cpu().tolist())
+                true_rt.extend(rt.data.cpu().tolist())
+                true_int.extend(intensity.data.cpu().tolist())
 
         data_frame['rt pred'] = pred_rt
         data_frame['seq'] = seqs
-- 
GitLab