diff --git a/main_custom.py b/main_custom.py index 0cfc1cc5a893c67f23bf40ec43eeb7a24a8945e5..994b6019d0c19e7f6f3602a767cba6d8f7bcbd56 100644 --- a/main_custom.py +++ b/main_custom.py @@ -284,7 +284,7 @@ def save_pred(model, data_val, forward, output_path, file = False): for seq, charge, rt, intensity, files in data_val: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): - seq, charge, rt, intensity, file = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda(), file.cuda() + seq, charge, rt, intensity, files = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda(), files.cuda() pr_rt, pr_intensity = model.forward(seq, charge) pred_rt.extend(pr_rt.data.cpu().tolist()) pred_int.extend(pr_intensity.data.cpu().tolist())