Skip to content
Snippets Groups Projects
Commit fb2b1c63 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix save preds with file

parent 404f78b9
No related branches found
No related tags found
No related merge requests found
......@@ -281,8 +281,8 @@ def save_pred(model, data_val, forward, output_path, file = False):
pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
if file:
data_val.dataset.set_file_mode(True)
for seq, charge, rt, intensity, file in data_val:
rt, intensity = rt.float(), intensity.float()
for seq, charge, rt, intensity, file in data_val:
rt, intensity = rt.float(), intensity.float()
if torch.cuda.is_available():
seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
pr_rt, pr_intensity = model.forward(seq, charge)
......@@ -293,13 +293,27 @@ def save_pred(model, data_val, forward, output_path, file = False):
true_rt.extend(rt.data.cpu().tolist())
true_int.extend(intensity.data.cpu().tolist())
file_list.extend([file])
else :
for seq, charge, rt, intensity in data_val:
rt, intensity = rt.float(), intensity.float()
if torch.cuda.is_available():
seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
pr_rt, pr_intensity = model.forward(seq, charge)
pred_rt.extend(pr_rt.data.cpu().tolist())
pred_int.extend(pr_intensity.data.cpu().tolist())
seqs.extend(seq.data.cpu().tolist())
charges.extend(charge.data.cpu().tolist())
true_rt.extend(rt.data.cpu().tolist())
true_int.extend(intensity.data.cpu().tolist())
data_frame['rt pred'] = pred_rt
data_frame['seq'] = seqs
data_frame['pred int'] = pred_int
data_frame['true rt'] = true_rt
data_frame['true int'] = true_int
data_frame['charge'] = charges
data_frame['file'] = file_list
if file :
data_frame['file'] = file_list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment