From 2a70c389317100b528069ce7eef77e5ef7e589c2 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 30 Sep 2024 10:30:29 +0200 Subject: [PATCH] fix save preds with file --- main_custom.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/main_custom.py b/main_custom.py index 7985488..65b27ff 100644 --- a/main_custom.py +++ b/main_custom.py @@ -285,26 +285,26 @@ def save_pred(model, data_val, forward, output_path, file = False): rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() - pr_rt, pr_intensity = model.forward(seq, charge) - pred_rt.extend(pr_rt.data.cpu().tolist()) - pred_int.extend(pr_intensity.data.cpu().tolist()) - seqs.extend(seq.data.cpu().tolist()) - charges.extend(charge.data.cpu().tolist()) - true_rt.extend(rt.data.cpu().tolist()) - true_int.extend(intensity.data.cpu().tolist()) - file_list.extend([files]) + pr_rt, pr_intensity = model.forward(seq, charge) + pred_rt.extend(pr_rt.data.cpu().tolist()) + pred_int.extend(pr_intensity.data.cpu().tolist()) + seqs.extend(seq.data.cpu().tolist()) + charges.extend(charge.data.cpu().tolist()) + true_rt.extend(rt.data.cpu().tolist()) + true_int.extend(intensity.data.cpu().tolist()) + file_list.extend([files]) else : for seq, charge, rt, intensity in data_val: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() - pr_rt, pr_intensity = model.forward(seq, charge) - pred_rt.extend(pr_rt.data.cpu().tolist()) - pred_int.extend(pr_intensity.data.cpu().tolist()) - seqs.extend(seq.data.cpu().tolist()) - charges.extend(charge.data.cpu().tolist()) - true_rt.extend(rt.data.cpu().tolist()) - true_int.extend(intensity.data.cpu().tolist()) + pr_rt, pr_intensity = model.forward(seq, charge) + pred_rt.extend(pr_rt.data.cpu().tolist()) + pred_int.extend(pr_intensity.data.cpu().tolist()) + seqs.extend(seq.data.cpu().tolist()) + charges.extend(charge.data.cpu().tolist()) + true_rt.extend(rt.data.cpu().tolist()) + true_int.extend(intensity.data.cpu().tolist()) data_frame['rt pred'] = pred_rt data_frame['seq'] = seqs -- GitLab