From 48761f6f0f955aabb3b312442b42e74b9afd839d Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 30 Sep 2024 10:13:43 +0200
Subject: [PATCH] fix save preds with file

---
 main_custom.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index ff5a2b1..7985488 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -200,7 +200,6 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
 
 
 def main(args):
-    print('file', args.file)
     if args.wandb is not None:
         os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
         os.environ["WANDB_MODE"] = "offline"
@@ -282,7 +281,7 @@ def save_pred(model, data_val, forward, output_path, file = False):
         pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
         if file:
             data_val.dataset.set_file_mode(True)
-            for seq, charge, rt, intensity, file in data_val:
+            for seq, charge, rt, intensity, files in data_val:
                 rt, intensity = rt.float(), intensity.float()
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
@@ -293,7 +292,7 @@ def save_pred(model, data_val, forward, output_path, file = False):
             charges.extend(charge.data.cpu().tolist())
             true_rt.extend(rt.data.cpu().tolist())
             true_int.extend(intensity.data.cpu().tolist())
-            file_list.extend([file])
+            file_list.extend([files])
         else :
             for seq, charge, rt, intensity in data_val:
                 rt, intensity = rt.float(), intensity.float()
-- 
GitLab