diff --git a/main_custom.py b/main_custom.py index 994b6019d0c19e7f6f3602a767cba6d8f7bcbd56..7de21bdad583b8cfe578e8526ce550d15b2cc3b3 100644 --- a/main_custom.py +++ b/main_custom.py @@ -346,7 +346,8 @@ def save_pred(model, data_val, forward, output_path, file = False): data_frame['pred int'] = pred_int data_frame['true int'] = true_int data_frame['charge'] = charges - data_val.data.set_file_mode(False) + if file : + data_val.dataset.set_file_mode(False) data_frame.to_csv(output_path)