diff --git a/dependencies/tiny-cuda-nn b/dependencies/tiny-cuda-nn
index 6a835fd7ed8f76cd7ac0a9744b79da8b67e17c14..8e6e242f36dd197134c9b9275a8e5108a8e3af78 160000
--- a/dependencies/tiny-cuda-nn
+++ b/dependencies/tiny-cuda-nn
@@ -1 +1 @@
-Subproject commit 6a835fd7ed8f76cd7ac0a9744b79da8b67e17c14
+Subproject commit 8e6e242f36dd197134c9b9275a8e5108a8e3af78
diff --git a/include/neural-graphics-primitives/nerf_network.h b/include/neural-graphics-primitives/nerf_network.h
index 9344046b55ceb2cd150bee7da2df11a1b79edfc6..4649472bb77c5ecaf1dde7d8cb5ca19611120fae 100644
--- a/include/neural-graphics-primitives/nerf_network.h
+++ b/include/neural-graphics-primitives/nerf_network.h
@@ -358,86 +358,33 @@ public:
 		}
 	}
 
-	void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
+	void set_params_impl(T* params, T* inference_params, T* gradients) override {
 		size_t offset = 0;
-		m_density_network->set_params(
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset
-		);
+		m_density_network->set_params(params + offset, inference_params + offset, gradients + offset);
 		offset += m_density_network->n_params();
 
-		m_rgb_network->set_params(
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset
-		);
+		m_rgb_network->set_params(params + offset, inference_params + offset, gradients + offset);
 		offset += m_rgb_network->n_params();
 
-		m_pos_encoding->set_params(
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset
-		);
+		m_pos_encoding->set_params(params + offset, inference_params + offset, gradients + offset);
 		offset += m_pos_encoding->n_params();
 
-		m_dir_encoding->set_params(
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset
-		);
+		m_dir_encoding->set_params(params + offset, inference_params + offset, gradients + offset);
 		offset += m_dir_encoding->n_params();
 	}
 
-	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
-		size_t offset = 0;
-		m_density_network->initialize_params(
-			rnd,
-			params_full_precision + offset,
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset,
-			scale
-		);
-		offset += m_density_network->n_params();
+	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, float scale = 1) override {
+		m_density_network->initialize_params(rnd, params_full_precision, scale);
+		params_full_precision += m_density_network->n_params();
 
-		m_rgb_network->initialize_params(
-			rnd,
-			params_full_precision + offset,
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset,
-			scale
-		);
-		offset += m_rgb_network->n_params();
+		m_rgb_network->initialize_params(rnd, params_full_precision, scale);
+		params_full_precision += m_rgb_network->n_params();
 
-		m_pos_encoding->initialize_params(
-			rnd,
-			params_full_precision + offset,
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset,
-			scale
-		);
-		offset += m_pos_encoding->n_params();
+		m_pos_encoding->initialize_params(rnd, params_full_precision, scale);
+		params_full_precision += m_pos_encoding->n_params();
 
-		m_dir_encoding->initialize_params(
-			rnd,
-			params_full_precision + offset,
-			params + offset,
-			inference_params + offset,
-			backward_params + offset,
-			gradients + offset,
-			scale
-		);
-		offset += m_dir_encoding->n_params();
+		m_dir_encoding->initialize_params(rnd, params_full_precision, scale);
+		params_full_precision += m_dir_encoding->n_params();
 	}
 
 	size_t n_params() const override {
diff --git a/include/neural-graphics-primitives/takikawa_encoding.cuh b/include/neural-graphics-primitives/takikawa_encoding.cuh
index 1167464b398181f346d94b602538e46819b3bfe1..ac6f6edded4abf0f5c701344ed73a405c0bc8290 100644
--- a/include/neural-graphics-primitives/takikawa_encoding.cuh
+++ b/include/neural-graphics-primitives/takikawa_encoding.cuh
@@ -330,7 +330,7 @@ public:
 			m_interpolation_type,
 			m_octree->nodes_gpu(),
 			m_octree->dual_nodes_gpu(),
-			use_inference_params ? m_params_inference : m_params,
+			use_inference_params ? this->inference_params() : this->params(),
 			input.view(),
 			output ? output->view() : tcnn::MatrixView<T>{},
 			forward->dy_dx.data()
@@ -366,7 +366,7 @@ public:
 				params_gradient_tmp = tcnn::allocate_workspace(stream, n_params() * sizeof(grad_t));
 				params_gradient = (grad_t*)params_gradient_tmp.data();
 			} else {
-				params_gradient = (grad_t*)m_params_gradient;
+				params_gradient = (grad_t*)this->gradients();
 			}
 
 			if (param_gradients_mode == tcnn::EGradientMode::Overwrite) {
@@ -386,7 +386,7 @@ public:
 			);
 
 			if (!std::is_same<grad_t, T>::value) {
-				parallel_for_gpu(stream, n_params(), [grad=m_params_gradient, grad_tmp=params_gradient] __device__ (size_t i) {
+				parallel_for_gpu(stream, n_params(), [grad=this->gradients(), grad_tmp=params_gradient] __device__ (size_t i) {
 					grad[i] = (T)grad_tmp[i];
 				});
 			}
@@ -428,17 +428,11 @@ public:
 		return N_FEATURES_PER_LEVEL;
 	}
 
-	void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
-		m_params = params;
-		m_params_inference = inference_params;
-		m_params_gradient = gradients;
-	}
-
-	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
-		set_params(params, inference_params, backward_params, gradients);
+	void set_params_impl(T* params, T* inference_params, T* gradients) override { }
 
+	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, float scale = 1) override {
 		// Initialize the encoding from the GPU, because the number of parameters can be quite large.
-		tcnn::generate_random_uniform<float>(rnd, n_params(), params_full_precision, -1e-4f, 1e-4f);
+		tcnn::generate_random_uniform<float>(rnd, n_params(), params_full_precision, -1e-4f * scale, 1e-4f * scale);
 	}
 
 	size_t n_params() const override {
@@ -473,11 +467,6 @@ private:
 	uint32_t m_n_output_dims;
 	uint32_t m_n_to_pad = 0;
 
-	// Storage of params
-	T* m_params;
-	T* m_params_inference;
-	T* m_params_gradient;
-
 	std::shared_ptr<TriangleOctree> m_octree;
 	tcnn::InterpolationType m_interpolation_type;
 };
diff --git a/include/neural-graphics-primitives/trainable_buffer.cuh b/include/neural-graphics-primitives/trainable_buffer.cuh
index 8ea9125baf952dc00580bd16487a6318b1f23850..0672b18aceb93eccbf18a906a276b414d12cb516 100644
--- a/include/neural-graphics-primitives/trainable_buffer.cuh
+++ b/include/neural-graphics-primitives/trainable_buffer.cuh
@@ -58,15 +58,9 @@ public:
 		throw std::runtime_error{"The trainable buffer does not support backward(). Its content is meant to be used externally."};
 	}
 
-	void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
-		m_params = params;
-		m_params_inference = inference_params;
-		m_params_gradient = gradients;
-	}
-
-	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
-		set_params(params, inference_params, backward_params, gradients);
+	void set_params_impl(T* params, T* inference_params, T* gradients) override { }
 
+	void initialize_params(tcnn::pcg32& rnd, float* params_full_precision, float scale = 1) override {
 		// Initialize the buffer to zero from the GPU
 		CUDA_CHECK_THROW(cudaMemset(params_full_precision, 0, n_params()*sizeof(float)));
 	}
@@ -95,22 +89,10 @@ public:
 		return {};
 	}
 
-	T* gradients() const {
-		return m_params_gradient;
-	}
-
 	T* gradient_weights() const {
 		return m_params_gradient_weight.data();
 	}
 
-	T* params() const {
-		return m_params;
-	}
-
-	T* params_inference() const {
-		return m_params_inference;
-	}
-
 	tcnn::json hyperparams() const override {
 		return {
 			{"otype", "TrainableBuffer"},
@@ -119,10 +101,6 @@ public:
 
 private:
 	ResVector m_resolution;
-
-	T* m_params = nullptr;
-	T* m_params_inference = nullptr;
-	T* m_params_gradient = nullptr;
 	tcnn::GPUMemory<T> m_params_gradient_weight;
 };
 
diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu
index 1406f9d5d351ac390cf1f0ccfa8cae4ff3dd293b..2f20fadeb14ebb9d5444e83c7b0fc4b5f09e3cc2 100644
--- a/src/testbed_nerf.cu
+++ b/src/testbed_nerf.cu
@@ -2278,9 +2278,9 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 		plane_z,
 		m_aperture_size,
 		lens,
-		m_envmap.envmap->params_inference(),
+		m_envmap.envmap->inference_params(),
 		m_envmap.resolution,
-		render_grid_distortion ? m_distortion.map->params_inference() : nullptr,
+		render_grid_distortion ? m_distortion.map->inference_params() : nullptr,
 		m_distortion.resolution,
 		render_buffer.frame_buffer(),
 		render_buffer.depth_buffer(),
@@ -3454,7 +3454,8 @@ int Testbed::marching_cubes(Vector3i res3d, const BoundingBox& aabb, const Matri
 	m_mesh.verts_gradient.copy_from_device(m_mesh.verts); // Make sure the vertices don't get destroyed in the initialization
 
 	pcg32 rnd{m_seed};
-	m_mesh.trainable_verts->initialize_params(rnd, (float*)m_mesh.verts.data(), (float*)m_mesh.verts.data(), (float*)m_mesh.verts.data(), (float*)m_mesh.verts.data(), (float*)m_mesh.verts_gradient.data());
+	m_mesh.trainable_verts->initialize_params(rnd, (float*)m_mesh.verts.data());
+	m_mesh.trainable_verts->set_params((float*)m_mesh.verts.data(), (float*)m_mesh.verts.data(), (float*)m_mesh.verts_gradient.data());
 	m_mesh.verts.copy_from_device(m_mesh.verts_gradient);
 
 	m_mesh.verts_optimizer.reset(create_optimizer<float>({
diff --git a/src/testbed_sdf.cu b/src/testbed_sdf.cu
index f7c191222bc85f6e358f025f58f4e402f78b4cd0..780dfb20200a034040582e08de01036d263b6153 100644
--- a/src/testbed_sdf.cu
+++ b/src/testbed_sdf.cu
@@ -850,7 +850,7 @@ void Testbed::render_sdf(
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->params_inference(),
+		m_envmap.envmap->inference_params(),
 		m_envmap.resolution,
 		render_buffer.frame_buffer(),
 		render_buffer.depth_buffer(),
diff --git a/src/testbed_volume.cu b/src/testbed_volume.cu
index a5b9d7363cdc59161ba52e297e0cd43d84b6c7bd..6559f08b0ac9b4a255f6b0f1d31f6f90b538160b 100644
--- a/src/testbed_volume.cu
+++ b/src/testbed_volume.cu
@@ -427,7 +427,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->params_inference(),
+		m_envmap.envmap->inference_params(),
 		m_envmap.resolution,
 		render_buffer.frame_buffer(),
 		render_buffer.depth_buffer(),