diff --git a/main_custom.py b/main_custom.py index 8bcce6f6a2598e4530e6ea721fde157abc3a6df8..514c0bcc85ce59048599adced37e45518b53f384 100644 --- a/main_custom.py +++ b/main_custom.py @@ -272,15 +272,15 @@ def get_n_params(model): return pp def save_pred(model, data_val, forward, output_path, file = False): - if file : - data_val.dataset.set_file_mode(True) + data_frame = pd.DataFrame() model.eval() for param in model.parameters(): param.requires_grad = False if forward == 'both': pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], [] - data_val.data.set_file_mode(True) + if file: + data_val.dataset.set_file_mode(True) for seq, charge, rt, intensity, file in data_val: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available():