From 2235c70bf8e1c7f15f0f7211b356b5b3a4ecfd2f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thomas=20M=C3=BCller?= <tmueller@nvidia.com>
Date: Wed, 18 Jan 2023 10:41:12 +0100
Subject: [PATCH] Python API improvements

---
 src/python_api.cu | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/python_api.cu b/src/python_api.cu
index 91d32c4..f69fc8c 100644
--- a/src/python_api.cu
+++ b/src/python_api.cu
@@ -365,7 +365,7 @@ PYBIND11_MODULE(pyngp, m) {
 			py::arg("shutter_fraction") = 1.0f
 		)
 		.def("destroy_window", &Testbed::destroy_window, "Destroy the window again.")
-		.def("train", &Testbed::train, py::call_guard<py::gil_scoped_release>(), "Perform a specified number of training steps.")
+		.def("train", &Testbed::train, py::call_guard<py::gil_scoped_release>(), "Perform a single training step with a specified batch size.")
 		.def("reset", &Testbed::reset_network, py::arg("reset_density_grid") = true, "Reset training.")
 		.def("reset_accumulation", &Testbed::reset_accumulation, "Reset rendering accumulation.",
 			py::arg("due_to_camera_movement") = false,
@@ -385,7 +385,7 @@ PYBIND11_MODULE(pyngp, m) {
 		)
 		.def("n_params", &Testbed::n_params, "Number of trainable parameters")
 		.def("n_encoding_params", &Testbed::n_encoding_params, "Number of trainable parameters in the encoding")
-		.def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, "Save a snapshot of the currently trained model")
+		.def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, py::arg("compress")=true, "Save a snapshot of the currently trained model. Optionally compressed (only when saving '.ingp' files).")
 		.def("load_snapshot", &Testbed::load_snapshot, py::arg("path"), "Load a previously saved snapshot")
 		.def("load_camera_path", &Testbed::load_camera_path, "Load a camera path", py::arg("path"))
 		.def("load_file", &Testbed::load_file, "Load a file and automatically determine how to handle it. Can be a snapshot, dataset, network config, or camera path.", py::arg("path"))
@@ -428,6 +428,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.def_readwrite("shall_train_encoding", &Testbed::m_train_encoding)
 		.def_readwrite("shall_train_network", &Testbed::m_train_network)
 		.def_readwrite("render_groundtruth", &Testbed::m_render_ground_truth)
+		.def_readwrite("render_ground_truth", &Testbed::m_render_ground_truth)
 		.def_readwrite("groundtruth_render_mode", &Testbed::m_ground_truth_render_mode)
 		.def_readwrite("render_mode", &Testbed::m_render_mode)
 		.def_readwrite("render_near_distance", &Testbed::m_render_near_distance)
@@ -546,11 +547,12 @@ PYBIND11_MODULE(pyngp, m) {
 
 	py::class_<TrainingImageMetadata> metadata(m, "TrainingImageMetadata");
 	metadata
-		.def_readwrite("focal_length", &TrainingImageMetadata::focal_length)
 		// Legacy member: lens used to be called "camera_distortion"
 		.def_readwrite("camera_distortion", &TrainingImageMetadata::lens)
 		.def_readwrite("lens", &TrainingImageMetadata::lens)
+		.def_readwrite("resolution", &TrainingImageMetadata::resolution)
 		.def_readwrite("principal_point", &TrainingImageMetadata::principal_point)
+		.def_readwrite("focal_length", &TrainingImageMetadata::focal_length)
 		.def_readwrite("rolling_shutter", &TrainingImageMetadata::rolling_shutter)
 		.def_readwrite("light_dir", &TrainingImageMetadata::light_dir)
 		;
-- 
GitLab