Skip to content
Snippets Groups Projects
Commit 90bf2b77 authored by Léo Schneider's avatar Léo Schneider
Browse files

save preds

parent 6ac8b84f
No related branches found
No related tags found
No related merge requests found
import os
import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb
......@@ -246,8 +248,68 @@ def get_n_params(model):
pp += nn
return pp
def save_pred(model, data_val, forward, output_path):
data_frame = pd.DataFrame()
for param in model.parameters():
param.requires_grad = False
if forward == 'both':
pred_rt, pred_int, seqs, charges, true_rt, true_int = [], [], [], [], [], []
for seq, charge, rt, intensity in data_val:
rt, intensity = rt.float(), intensity.float()
if torch.cuda.is_available():
seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
pr_rt, pr_intensity = model.forward(seq, charge)
pred_rt.extend(pr_rt.data.cpu().tolist())
pred_int.extend(pr_intensity.data.cpu().tolist())
seqs.extend(seq.data.cpu().tolist())
charges.extend(charge.data.cpu().tolist())
true_rt.extend(rt.data.cpu().tolist())
true_int.extend(intensity.data.cpu().tolist())
data_frame['rt pred'] = pred_rt
data_frame['seq'] = seqs
data_frame['pred int'] = pred_int
data_frame['true rt'] = true_rt
data_frame['true int'] = true_int
data_frame['charge'] = charges
if forward == 'rt': #adapted to prosit dataset format
pred_rt, seqs, true_rt = [], [], []
for seq, rt in data_val:
rt = rt.float()
if torch.cuda.is_available():
seq, rt = seq.cuda(), rt.cuda()
pr_rt = model.forward_rt(seq)
pred_rt.extend(pr_rt.data.cpu().tolist())
seqs.extend(seq.data.cpu().tolist())
true_rt.extend(rt.data.cpu().tolist())
data_frame['rt pred'] = pred_rt
data_frame['seq'] = seqs
data_frame['true rt'] = true_rt
if forward == 'int': #adapted to prosit dataset format
pred_int, seqs, charges, true_int = [], [], [], []
for seq, charge, _, intensity in data_val:
intensity = intensity.float()
if torch.cuda.is_available():
seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
pred_int = model.forward_int(seq, charge)
seqs.extend(seq.data.cpu().tolist())
charges.extend(charge.data.cpu().tolist())
true_int.extend(intensity.data.cpu().tolist())
data_frame['seq'] = seqs
data_frame['pred int'] = pred_int
data_frame['true int'] = true_int
data_frame['charge'] = charges
data_frame.to_csv(output_path)
if __name__ == "__main__":
args = load_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment