From 25dcf8e1491dfebc0a16ad163d843a8bdebe00d9 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 30 Sep 2024 10:57:22 +0200
Subject: [PATCH] fix save preds with file

---
 main_custom.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index 65b27ff..dd61d47 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -283,8 +283,8 @@ def save_pred(model, data_val, forward, output_path, file = False):
             data_val.dataset.set_file_mode(True)
             for seq, charge, rt, intensity, files in data_val:
                 rt, intensity = rt.float(), intensity.float()
-            if torch.cuda.is_available():
-                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
+                if torch.cuda.is_available():
+                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                 pr_rt, pr_intensity = model.forward(seq, charge)
                 pred_rt.extend(pr_rt.data.cpu().tolist())
                 pred_int.extend(pr_intensity.data.cpu().tolist())
@@ -296,8 +296,8 @@ def save_pred(model, data_val, forward, output_path, file = False):
         else :
             for seq, charge, rt, intensity in data_val:
                 rt, intensity = rt.float(), intensity.float()
-            if torch.cuda.is_available():
-                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
+                if torch.cuda.is_available():
+                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                 pr_rt, pr_intensity = model.forward(seq, charge)
                 pred_rt.extend(pr_rt.data.cpu().tolist())
                 pred_int.extend(pr_intensity.data.cpu().tolist())
-- 
GitLab