From 1bbfb0ca1d199586981b3a8d5bf538c006864ae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20M=C3=BCller?= <thomas94@gmx.net> Date: Tue, 10 Jan 2023 10:39:19 +0100 Subject: [PATCH] Work around internal CUTLASS error upon large batch sizes --- include/neural-graphics-primitives/nerf.h | 4 ++ src/testbed_nerf.cu | 71 ++++++++++++++--------- 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/include/neural-graphics-primitives/nerf.h b/include/neural-graphics-primitives/nerf.h index 27e3237..5cd2d2b 100644 --- a/include/neural-graphics-primitives/nerf.h +++ b/include/neural-graphics-primitives/nerf.h @@ -25,6 +25,10 @@ inline constexpr __device__ uint32_t NERF_GRIDSIZE() { return 128; } +inline constexpr __device__ uint32_t NERF_GRID_N_CELLS() { + return NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE(); +} + struct NerfPayload { Eigen::Vector3f origin; Eigen::Vector3f dir; diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu index 8f116c8..1b4624d 100644 --- a/src/testbed_nerf.cu +++ b/src/testbed_nerf.cu @@ -75,7 +75,7 @@ Testbed::NetworkDims Testbed::network_dims_nerf() const { } inline __host__ __device__ uint32_t grid_mip_offset(uint32_t mip) { - return (NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE()) * mip; + return NERF_GRID_N_CELLS() * mip; } inline __host__ __device__ float calc_cone_angle(float cosine, const Eigen::Vector2f& focal_length, float cone_angle_constant) { @@ -371,8 +371,8 @@ __global__ void mark_untrained_density_grid(const uint32_t n_elements, float* _ const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= n_elements) return; - uint32_t level = i / (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()); - uint32_t pos_idx = i % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()); + uint32_t level = i / NERF_GRID_N_CELLS(); + uint32_t pos_idx = i % NERF_GRID_N_CELLS(); uint32_t x = tcnn::morton3D_invert(pos_idx>>0); uint32_t y = tcnn::morton3D_invert(pos_idx>>1); @@ -414,10 +414,12 @@ __global__ void generate_grid_samples_nerf_uniform(Eigen::Vector3i res_3d, const uint32_t x = threadIdx.x + blockIdx.x * blockDim.x; uint32_t y = threadIdx.y + blockIdx.y * blockDim.y; uint32_t z = threadIdx.z + blockIdx.z * blockDim.z; - if (x>=res_3d.x() || y>=res_3d.y() || z>=res_3d.z()) + if (x >= res_3d.x() || y >= res_3d.y() || z >= res_3d.z()) { return; - uint32_t i = x+ y*res_3d.x() + z*res_3d.x()*res_3d.y(); - Vector3f pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d-Vector3i::Ones()).cast<float>()); + } + + uint32_t i = x + y * res_3d.x() + z * res_3d.x() * res_3d.y(); + Vector3f pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d - Vector3i::Ones()).cast<float>()); pos = render_aabb_to_local.transpose() * (pos.cwiseProduct(render_aabb.max - render_aabb.min) + render_aabb.min); out[i] = { warp_position(pos, train_aabb), warp_dt(MIN_CONE_STEPSIZE()) }; } @@ -428,14 +430,18 @@ __global__ void generate_grid_samples_nerf_uniform_dir(Eigen::Vector3i res_3d, c uint32_t x = threadIdx.x + blockIdx.x * blockDim.x; uint32_t y = threadIdx.y + blockIdx.y * blockDim.y; uint32_t z = threadIdx.z + blockIdx.z * blockDim.z; - if (x>=res_3d.x() || y>=res_3d.y() || z>=res_3d.z()) + if (x >= res_3d.x() || y >= res_3d.y() || z >= res_3d.z()) { return; + } + uint32_t i = x+ y*res_3d.x() + z*res_3d.x()*res_3d.y(); Vector3f pos; - if (voxel_centers) - pos = Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}.cwiseQuotient((res_3d).cast<float>()); - else - pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d-Vector3i::Ones()).cast<float>()); + if (voxel_centers) { + pos = Vector3f{(float)x + 0.5f, (float)y + 0.5f, (float)z + 0.5f}.cwiseQuotient((res_3d).cast<float>()); + } else { + pos = Vector3f{(float)x, (float)y, (float)z}.cwiseQuotient((res_3d - Vector3i::Ones()).cast<float>()); + } + pos = render_aabb_to_local.transpose() * (pos.cwiseProduct(render_aabb.max - render_aabb.min) + render_aabb.min); network_input[i] = { warp_position(pos, train_aabb), warp_direction(ray_dir), warp_dt(MIN_CONE_STEPSIZE()) }; } @@ -449,8 +455,11 @@ inline __device__ int mip_from_pos(const Vector3f& pos, uint32_t max_cascade = N inline __device__ int mip_from_dt(float dt, const Vector3f& pos, uint32_t max_cascade = NERF_CASCADES()-1) { int mip = mip_from_pos(pos, max_cascade); - dt *= 2*NERF_GRIDSIZE(); - if (dt<1.f) return mip; + dt *= 2 * NERF_GRIDSIZE(); + if (dt < 1.0f) { + return mip; + } + int exponent; frexpf(dt, &exponent); return min(max_cascade, max(exponent, mip)); @@ -467,15 +476,15 @@ __global__ void generate_grid_samples_nerf_nonuniform(const uint32_t n_elements, // Select grid cell that has density uint32_t idx; for (uint32_t j = 0; j < 10; ++j) { - idx = ((i+step*n_elements) * 56924617 + j * 19349663 + 96925573) % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()); - idx += level * NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE(); + idx = ((i+step*n_elements) * 56924617 + j * 19349663 + 96925573) % NERF_GRID_N_CELLS(); + idx += level * NERF_GRID_N_CELLS(); if (grid_in[idx] > thresh) { break; } } // Random position within that cellq - uint32_t pos_idx = idx % (NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()); + uint32_t pos_idx = idx % NERF_GRID_N_CELLS(); uint32_t x = tcnn::morton3D_invert(pos_idx>>0); uint32_t y = tcnn::morton3D_invert(pos_idx>>1); @@ -495,7 +504,7 @@ __global__ void splat_grid_samples_nerf_max_nearest_neighbor(const uint32_t n_el // Current setting: optical thickness of the smallest possible stepsize. // Uncomment for: optical thickness of the ~expected step size when the observer is in the middle of the scene - uint32_t level = 0;//local_idx / (NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE()); + uint32_t level = 0;//local_idx / NERF_GRID_N_CELLS(); float mlp = network_to_density(float(network_output[i]), density_activation); float optical_thickness = mlp * scalbnf(MIN_CONE_STEPSIZE(), level); @@ -2661,7 +2670,7 @@ void Testbed::load_nerf() { } void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_grid_samples, uint32_t n_nonuniform_density_grid_samples, cudaStream_t stream) { - const uint32_t n_elements = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * (m_nerf.max_cascade + 1); + const uint32_t n_elements = NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1); m_nerf.density_grid.resize(n_elements); @@ -2730,9 +2739,17 @@ void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_g ); m_nerf.training.density_grid_rng.advance(); - GPUMatrix<network_precision_t, RM> density_matrix(mlp_out, padded_output_width, n_density_grid_samples); - GPUMatrix<float> density_grid_position_matrix((float*)density_grid_positions, sizeof(NerfPosition)/sizeof(float), n_density_grid_samples); - m_nerf_network->density(stream, density_grid_position_matrix, density_matrix, false); + // Evaluate density at the spawned locations in batches. + // Otherwise, we can exhaust the maximum index range of cutlass + size_t batch_size = NERF_GRID_N_CELLS() * 2; + + for (size_t i = 0; i < n_density_grid_samples; i += batch_size) { + batch_size = std::min(batch_size, n_density_grid_samples - i); + + GPUMatrix<network_precision_t, RM> density_matrix(mlp_out + i, padded_output_width, batch_size); + GPUMatrix<float> density_grid_position_matrix((float*)(density_grid_positions + i), sizeof(NerfPosition)/sizeof(float), batch_size); + m_nerf_network->density(stream, density_grid_position_matrix, density_matrix, false); + } linear_kernel(splat_grid_samples_nerf_max_nearest_neighbor, 0, stream, n_density_grid_samples, density_grid_indices, mlp_out, density_grid_tmp, m_nerf.rgb_activation, m_nerf.density_activation); linear_kernel(ema_grid_samples_nerf, 0, stream, n_elements, decay, m_nerf.density_grid_ema_step, m_nerf.density_grid.data(), density_grid_tmp); @@ -2744,7 +2761,7 @@ void Testbed::update_density_grid_nerf(float decay, uint32_t n_uniform_density_g } void Testbed::update_density_grid_mean_and_bitfield(cudaStream_t stream) { - const uint32_t n_elements = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE(); + const uint32_t n_elements = NERF_GRID_N_CELLS(); size_t size_including_mips = grid_mip_offset(NERF_CASCADES())/8; m_nerf.density_grid_bitfield.enlarge(size_including_mips); @@ -2801,9 +2818,9 @@ void Testbed::train_nerf(uint32_t target_batch_size, bool get_loss_scalar, cudaS } if (m_nerf.training.include_sharpness_in_error) { - size_t n_cells = NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_CASCADES(); + size_t n_cells = NERF_GRID_N_CELLS() * NERF_CASCADES(); if (m_nerf.training.sharpness_grid.size() < n_cells) { - m_nerf.training.sharpness_grid.enlarge(NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_CASCADES()); + m_nerf.training.sharpness_grid.enlarge(NERF_GRID_N_CELLS() * NERF_CASCADES()); CUDA_CHECK_THROW(cudaMemsetAsync(m_nerf.training.sharpness_grid.data(), 0, m_nerf.training.sharpness_grid.get_bytes(), stream)); } @@ -3039,7 +3056,7 @@ void Testbed::train_nerf(uint32_t target_batch_size, bool get_loss_scalar, cudaS void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters& counters, cudaStream_t stream) { const uint32_t padded_output_width = m_network->padded_output_width(); - const uint32_t max_samples = target_batch_size * 16; // Somewhat of a worst case + const uint32_t max_samples = target_batch_size * 4; // Somewhat of a worst case const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + m_nerf_network->n_extra_dims(); const uint32_t extra_stride = m_nerf_network->n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct @@ -3296,9 +3313,9 @@ void Testbed::training_prep_nerf(uint32_t batch_size, cudaStream_t stream) { uint32_t n_cascades = m_nerf.max_cascade+1; if (m_training_step < 256) { - update_density_grid_nerf(alpha, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()*n_cascades, 0, stream); + update_density_grid_nerf(alpha, NERF_GRID_N_CELLS() * n_cascades, 0, stream); } else { - update_density_grid_nerf(alpha, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()/4*n_cascades, NERF_GRIDSIZE()*NERF_GRIDSIZE()*NERF_GRIDSIZE()/4*n_cascades, stream); + update_density_grid_nerf(alpha, NERF_GRID_N_CELLS() / 4 * n_cascades, NERF_GRID_N_CELLS() / 4 * n_cascades, stream); } } -- GitLab