diff --git a/main_custom.py b/main_custom.py index 65b27ff39a61c7c771d609de13053b44c51be7c2..dd61d477123f868d449a69f64bd98fdc77344546 100644 --- a/main_custom.py +++ b/main_custom.py @@ -283,8 +283,8 @@ def save_pred(model, data_val, forward, output_path, file = False): data_val.dataset.set_file_mode(True) for seq, charge, rt, intensity, files in data_val: rt, intensity = rt.float(), intensity.float() - if torch.cuda.is_available(): - seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() + if torch.cuda.is_available(): + seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() pr_rt, pr_intensity = model.forward(seq, charge) pred_rt.extend(pr_rt.data.cpu().tolist()) pred_int.extend(pr_intensity.data.cpu().tolist()) @@ -296,8 +296,8 @@ def save_pred(model, data_val, forward, output_path, file = False): else : for seq, charge, rt, intensity in data_val: rt, intensity = rt.float(), intensity.float() - if torch.cuda.is_available(): - seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() + if torch.cuda.is_available(): + seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() pr_rt, pr_intensity = model.forward(seq, charge) pred_rt.extend(pr_rt.data.cpu().tolist()) pred_int.extend(pr_intensity.data.cpu().tolist())