From 1bbfb0ca1d199586981b3a8d5bf538c006864ae9 Mon Sep 17 00:00:00 2001
From: Thomas Müller <thomas94@gmx.net>
Date: Tue, 10 Jan 2023 10:39:19 +0100
Subject: [PATCH] Work around internal CUTLASS error upon large batch sizes

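Evaluate the density grid samples in batches of at most
NERF_GRID_N_CELLS() * 2 positions rather than in a single network
query: very large sample counts could otherwise exceed CUTLASS's
internal index range and trigger the error.

While at it, introduce a NERF_GRID_N_CELLS() helper for the repeated
NERF_GRIDSIZE()^3 products and reduce the worst-case sample budget in
train_nerf_step() from 16x to 4x the target batch size.
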
---
 include/neural-graphics-primitives/nerf.h |  4 ++
 src/testbed_nerf.cu                       | 71 ++++++++++++++---------
 2 files changed, 48 insertions(+), 27 deletions(-)

diff --git a/include/neural-graphics-primitives/nerf.h b/include/neural-graphics-primitives/nerf.h
index 27e3237..5cd2d2b 100644
--- a/include/neural-graphics-primitives/nerf.h
+++ b/include/neural-graphics-primitives/nerf.h
@@ -25,6 +25,10 @@ inline constexpr __device__ uint32_t NERF_GRIDSIZE() {
 	return 128;
 }
 
+inline constexpr __device__ uint32_t NERF_GRID_N_CELLS() {
+	return NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE();
+}
+
 struct NerfPayload {
 	Eigen::Vector3f origin;
 	Eigen::Vector3f dir;
diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu
index 8f116c8..1b4624d 100644
--- a/src/testbed_nerf.cu
+++ b/src/testbed_nerf.cu
@@ -75,7 +75,7 @@ Testbed::NetworkDims Testbed::network_dims_nerf() const {
 }
 
 inline __host__ __device__ uint32_t grid_mip_offset(uint32_t mip) {
-	return (NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE()) * mip;
+	return NERF_GRID_N_CELLS() * mip;
 }
 
 inline __host__ __device__ float calc_cone_angle(float cosine, const Eigen::Vector2f& focal_length, float cone_angle_constant) {
@@ -371,8 +371,8 @@ __global__ void mark_untrained_density_grid(const uint32_t n_elements,  float* _
 	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
 	if (i >= n_elements) return;
 
-	uint32_t level = i / (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE());
-	uint32_t pos_idx = i % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE());
+	uint32_t level = i / NERF_GRID_N_CELLS();
+	uint32_t pos_idx = i % NERF_GRID_N_CELLS();
 
 	uint32_t x = tcnn::morton3D_invert(pos_idx>>0);
 	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
@@ -414,10 +414,12 @@ __global__ void generate_grid_samples_nerf_uniform(Eigen::Vector3i res_3d, const
 	uint32_t x = threadIdx.x + blockIdx.x * blockDim.x;
 	uint32_t y = threadIdx.y + blockIdx.y * blockDim.y;
 	uint32_t z = threadIdx.z + blockIdx.z * blockDim.z;
-	if (x>=res_3d.x() || y>=res_3d.y() || z>=res_3d.z())
+	if (x >= res_3d.x() || y >= res_3d.y() || z >= res_3d.z()) {
 		return;
-	uint32_t i = x+ y*res_3d.x() + z*res_3d.x()*res_3d.y();
-	Vector3f pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d-Vector3i::Ones()).cast<float>());
+	}
+
+	uint32_t i = x + y * res_3d.x() + z * res_3d.x() * res_3d.y();
+	Vector3f pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d - Vector3i::Ones()).cast<float>());
 	pos = render_aabb_to_local.transpose() * (pos.cwiseProduct(render_aabb.max - render_aabb.min) + render_aabb.min);
 	out[i] = { warp_position(pos, train_aabb), warp_dt(MIN_CONE_STEPSIZE()) };
 }
@@ -428,14 +430,18 @@ __global__ void generate_grid_samples_nerf_uniform_dir(Eigen::Vector3i res_3d, c
 	uint32_t x = threadIdx.x + blockIdx.x * blockDim.x;
 	uint32_t y = threadIdx.y + blockIdx.y * blockDim.y;
 	uint32_t z = threadIdx.z + blockIdx.z * blockDim.z;
-	if (x>=res_3d.x() || y>=res_3d.y() || z>=res_3d.z())
+	if (x >= res_3d.x() || y >= res_3d.y() || z >= res_3d.z()) {
 		return;
+	}
+
 	uint32_t i = x+ y*res_3d.x() + z*res_3d.x()*res_3d.y();
 	Vector3f pos;
-	if (voxel_centers)
-		pos = Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}.cwiseQuotient((res_3d).cast<float>());
-	else
-		pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d-Vector3i::Ones()).cast<float>());
+	if (voxel_centers) {
+		pos = Vector3f{(float)x + 0.5f, (float)y + 0.5f, (float)z + 0.5f}.cwiseQuotient((res_3d).cast<float>());
+	} else {
+		pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d - Vector3i::Ones()).cast<float>());
+	}
+
 	pos = render_aabb_to_local.transpose() * (pos.cwiseProduct(render_aabb.max - render_aabb.min) + render_aabb.min);
 	network_input[i] = { warp_position(pos, train_aabb), warp_direction(ray_dir), warp_dt(MIN_CONE_STEPSIZE()) };
 }
@@ -449,8 +455,11 @@ inline __device__ int mip_from_pos(const Vector3f& pos, uint32_t max_cascade = N
 
 inline __device__ int mip_from_dt(float dt, const Vector3f& pos, uint32_t max_cascade = NERF_CASCADES()-1) {
 	int mip = mip_from_pos(pos, max_cascade);
-	dt *= 2*NERF_GRIDSIZE();
-	if (dt<1.f) return mip;
+	dt *= 2 * NERF_GRIDSIZE();
+	if (dt < 1.0f) {
+		return mip;
+	}
+
 	int exponent;
 	frexpf(dt, &exponent);
 	return min(max_cascade, max(exponent, mip));
@@ -467,15 +476,15 @@ __global__ void generate_grid_samples_nerf_nonuniform(const uint32_t n_elements,
 	// Select grid cell that has density
 	uint32_t idx;
 	for (uint32_t j = 0; j < 10; ++j) {
-		idx = ((i+step*n_elements) * 56924617 + j * 19349663 + 96925573) % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE());
-		idx += level * NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE();
+		idx = ((i+step*n_elements) * 56924617 + j * 19349663 + 96925573) % NERF_GRID_N_CELLS();
+		idx += level * NERF_GRID_N_CELLS();
 		if (grid_in[idx] > thresh) {
 			break;
 		}
 	}
 
 	// Random position within that cell
-	uint32_t pos_idx = idx % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE());
+	uint32_t pos_idx = idx % NERF_GRID_N_CELLS();
 
 	uint32_t x = tcnn::morton3D_invert(pos_idx>>0);
 	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
@@ -495,7 +504,7 @@ __global__ void splat_grid_samples_nerf_max_nearest_neighbor(const uint32_t n_el
 
 	// Current setting: optical thickness of the smallest possible stepsize.
 	// Uncomment for:   optical thickness of the ~expected step size when the observer is in the middle of the scene
-	uint32_t level = 0;//local_idx / (NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE());
+	uint32_t level = 0;//local_idx / NERF_GRID_N_CELLS();
 
 	float mlp = network_to_density(float(network_output[i]), density_activation);
 	float optical_thickness = mlp * scalbnf(MIN_CONE_STEPSIZE(), level);
@@ -2661,7 +2670,7 @@ void Testbed::load_nerf() {
 }
 
 void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_grid_samples, uint32_t n_nonuniform_density_grid_samples, cudaStream_t stream) {
-	const uint32_t n_elements = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * (m_nerf.max_cascade + 1);
+	const uint32_t n_elements = NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1);
 
 	m_nerf.density_grid.resize(n_elements);
 
@@ -2730,9 +2739,17 @@ void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_g
 		);
 		m_nerf.training.density_grid_rng.advance();
 
-		GPUMatrix<network_precision_t, RM> density_matrix(mlp_out, padded_output_width, n_density_grid_samples);
-		GPUMatrix<float> density_grid_position_matrix((float*)density_grid_positions, sizeof(NerfPosition)/sizeof(float), n_density_grid_samples);
-		m_nerf_network->density(stream, density_grid_position_matrix, density_matrix, false);
+		// Evaluate density at the spawned locations in batches.
+		// Otherwise, we could exhaust the maximum index range of CUTLASS.
+		size_t batch_size = NERF_GRID_N_CELLS() * 2;
+
+		for (size_t i = 0; i < n_density_grid_samples; i += batch_size) {
+			batch_size = std::min(batch_size, n_density_grid_samples - i);
+
+			GPUMatrix<network_precision_t, RM> density_matrix(mlp_out + i, padded_output_width, batch_size);
+			GPUMatrix<float> density_grid_position_matrix((float*)(density_grid_positions + i), sizeof(NerfPosition)/sizeof(float), batch_size);
+			m_nerf_network->density(stream, density_grid_position_matrix, density_matrix, false);
+		}
 
 		linear_kernel(splat_grid_samples_nerf_max_nearest_neighbor, 0, stream, n_density_grid_samples, density_grid_indices, mlp_out, density_grid_tmp, m_nerf.rgb_activation, m_nerf.density_activation);
 		linear_kernel(ema_grid_samples_nerf, 0, stream, n_elements, decay, m_nerf.density_grid_ema_step, m_nerf.density_grid.data(), density_grid_tmp);
@@ -2744,7 +2761,7 @@ void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_g
 }
 
 void Testbed::update_density_grid_mean_and_bitfield(cudaStream_t stream) {
-	const uint32_t n_elements = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE();
+	const uint32_t n_elements = NERF_GRID_N_CELLS();
 
 	size_t size_including_mips = grid_mip_offset(NERF_CASCADES())/8;
 	m_nerf.density_grid_bitfield.enlarge(size_including_mips);
@@ -2801,9 +2818,9 @@ void Testbed::train_nerf(uint32_t target_batch_size, bool get_loss_scalar, cudaS
 	}
 
 	if (m_nerf.training.include_sharpness_in_error) {
-		size_t n_cells = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_CASCADES();
+		size_t n_cells = NERF_GRID_N_CELLS() * NERF_CASCADES();
 		if (m_nerf.training.sharpness_grid.size() < n_cells) {
-			m_nerf.training.sharpness_grid.enlarge(NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_CASCADES());
+			m_nerf.training.sharpness_grid.enlarge(NERF_GRID_N_CELLS() * NERF_CASCADES());
 			CUDA_CHECK_THROW(cudaMemsetAsync(m_nerf.training.sharpness_grid.data(), 0, m_nerf.training.sharpness_grid.get_bytes(), stream));
 		}
 
@@ -3039,7 +3056,7 @@ void Testbed::train_nerf(uint32_t target_batch_size, bool get_loss_scalar, cudaS
 
 void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters& counters, cudaStream_t stream) {
 	const uint32_t padded_output_width = m_network->padded_output_width();
-	const uint32_t max_samples = target_batch_size * 16; // Somewhat of a worst case
+	const uint32_t max_samples = target_batch_size * 4; // Somewhat of a worst case
 	const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + m_nerf_network->n_extra_dims();
 	const uint32_t extra_stride = m_nerf_network->n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct
 
@@ -3296,9 +3313,9 @@ void Testbed::training_prep_nerf(uint32_t batch_size, cudaStream_t stream) {
 	uint32_t n_cascades = m_nerf.max_cascade+1;
 
 	if (m_training_step < 256) {
-		update_density_grid_nerf(alpha, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()*n_cascades, 0, stream);
+		update_density_grid_nerf(alpha, NERF_GRID_N_CELLS() * n_cascades, 0, stream);
 	} else {
-		update_density_grid_nerf(alpha, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()/4*n_cascades, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()/4*n_cascades, stream);
+		update_density_grid_nerf(alpha, NERF_GRID_N_CELLS() / 4 * n_cascades, NERF_GRID_N_CELLS() / 4 * n_cascades, stream);
 	}
 }
 
-- 
GitLab