diff --git a/main_custom.py b/main_custom.py index b0866f7bdb1c1051dbf188a9999d50fb23ef3f2f..8bcce6f6a2598e4530e6ea721fde157abc3a6df8 100644 --- a/main_custom.py +++ b/main_custom.py @@ -273,7 +273,7 @@ def get_n_params(model): def save_pred(model, data_val, forward, output_path, file = False): if file : - data_val.data.set_file_mode(True) + data_val.dataset.set_file_mode(True) data_frame = pd.DataFrame() model.eval() for param in model.parameters():