diff --git a/include/neural-graphics-primitives/thread_pool.h b/include/neural-graphics-primitives/thread_pool.h
index 37f7fd53b18373652e3138e65de96f84c3eda678..c991f22aee2c54fd012993e85f1cf75e9e33092b 100644
--- a/include/neural-graphics-primitives/thread_pool.h
+++ b/include/neural-graphics-primitives/thread_pool.h
@@ -25,83 +25,83 @@ NGP_NAMESPACE_BEGIN
 
 template <typename T>
-void waitAll(T&& futures) {
-    for (auto& f : futures) {
-        f.get();
-    }
+void wait_all(T&& futures) {
+	for (auto& f : futures) {
+		f.get();
+	}
 }
 
 class ThreadPool {
 public:
-    ThreadPool();
-    ThreadPool(size_t maxNumThreads, bool force = false);
-    virtual ~ThreadPool();
-
-    template <class F>
-    auto enqueueTask(F&& f, bool highPriority = false) -> std::future<std::result_of_t <F()>> {
-        using return_type = std::result_of_t<F()>;
-
-        auto task = std::make_shared<std::packaged_task<return_type()>>(std::forward<F>(f));
-
-        auto res = task->get_future();
-
-        {
-            std::lock_guard<std::mutex> lock{mTaskQueueMutex};
-
-            if (highPriority) {
-                mTaskQueue.emplace_front([task]() { (*task)(); });
-            } else {
-                mTaskQueue.emplace_back([task]() { (*task)(); });
-            }
-        }
-
-        mWorkerCondition.notify_one();
-        return res;
-    }
-
-    void startThreads(size_t num);
-    void shutdownThreads(size_t num);
-    void setNThreads(size_t num);
-
-    void flushQueue();
-
-    template <typename Int, typename F>
-    void parallelForAsync(Int start, Int end, F body, std::vector<std::future<void>>& futures) {
-        Int localNumThreads = (Int)mNumThreads;
-
-        Int range = end - start;
-        Int chunk = (range / localNumThreads) + 1;
-
-        for (Int i = 0; i < localNumThreads; ++i) {
-            futures.emplace_back(enqueueTask([i, chunk, start, end, body] {
-                Int innerStart = start + i * chunk;
-                Int innerEnd = std::min(end, start + (i + 1) * chunk);
-                for (Int j = innerStart; j < innerEnd; ++j) {
-                    body(j);
-                }
-            }));
-        }
-    }
-
-    template <typename Int, typename F>
-    std::vector<std::future<void>> parallelForAsync(Int start, Int end, F body) {
-        std::vector<std::future<void>> futures;
-        parallelForAsync(start, end, body, futures);
-        return futures;
-    }
-
-    template <typename Int, typename F>
-    void parallelFor(Int start, Int end, F body) {
-        waitAll(parallelForAsync(start, end, body));
-    }
+	ThreadPool();
+	ThreadPool(size_t max_num_threads, bool force = false);
+	virtual ~ThreadPool();
+
+	template <class F>
+	auto enqueue_task(F&& f, bool high_priority = false) -> std::future<std::result_of_t<F()>> {
+		using return_type = std::result_of_t<F()>;
+
+		auto task = std::make_shared<std::packaged_task<return_type()>>(std::forward<F>(f));
+
+		auto res = task->get_future();
+
+		{
+			std::lock_guard<std::mutex> lock{m_task_queue_mutex};
+
+			if (high_priority) {
+				m_task_queue.emplace_front([task]() { (*task)(); });
+			} else {
+				m_task_queue.emplace_back([task]() { (*task)(); });
+			}
+		}
+
+		m_worker_condition.notify_one();
+		return res;
+	}
+
+	void start_threads(size_t num);
+	void shutdown_threads(size_t num);
+	void set_n_threads(size_t num);
+
+	void flush_queue();
+
+	template <typename Int, typename F>
+	void parallel_for_async(Int start, Int end, F body, std::vector<std::future<void>>& futures) {
+		Int local_num_threads = (Int)m_num_threads;
+
+		Int range = end - start;
+		Int chunk = (range / local_num_threads) + 1;
+
+		for (Int i = 0; i < local_num_threads; ++i) {
+			futures.emplace_back(enqueue_task([i, chunk, start, end, body] {
+				Int inner_start = start + i * chunk;
+				Int inner_end = std::min(end, start + (i + 1) * chunk);
+				for (Int j = inner_start; j < inner_end; ++j) {
+					body(j);
+				}
+			}));
+		}
+	}
+
+	template <typename Int, typename F>
+	std::vector<std::future<void>> parallel_for_async(Int start, Int end, F body) {
+		std::vector<std::future<void>> futures;
+		parallel_for_async(start, end, body, futures);
+		return futures;
+	}
+
+	template <typename Int, typename F>
+	void parallel_for(Int start, Int end, F body) {
+		wait_all(parallel_for_async(start, end, body));
+	}
 
 private:
-    size_t mNumThreads = 0;
-    std::vector<std::thread> mThreads;
+	size_t m_num_threads = 0;
+	std::vector<std::thread> m_threads;
 
-    std::deque<std::function<void()>> mTaskQueue;
-    std::mutex mTaskQueueMutex;
-    std::condition_variable mWorkerCondition;
+	std::deque<std::function<void()>> m_task_queue;
+	std::mutex m_task_queue_mutex;
+	std::condition_variable m_worker_condition;
 };
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/triangle_octree.cuh b/include/neural-graphics-primitives/triangle_octree.cuh
index 8ce4b89ff70ee1fdc1114a4e01eb8128e8deb7eb..dbc427b9612f2be00a2c0a140e7f8b221b1baa4a 100644
--- a/include/neural-graphics-primitives/triangle_octree.cuh
+++ b/include/neural-graphics-primitives/triangle_octree.cuh
@@ -129,7 +129,7 @@ public:
 			int last_n_nodes = n_nodes;
 			n_nodes = node_counter;
 
-			pool.parallelFor<int>(last_n_nodes, node_counter, [&](size_t parent_idx) {
+			pool.parallel_for<int>(last_n_nodes, node_counter, [&](size_t parent_idx) {
 				Vector3i16 child_pos_base = m_nodes[parent_idx].pos * (uint16_t)2;
 
 				float size = std::scalbnf(1.0f, -depth-1);
diff --git a/src/marching_cubes.cu b/src/marching_cubes.cu
index edcfb0c13b86ecfae56e37217ee25cb823581d26..d794d63548651c2cbad5d7dfb4fe711a7aa09ee8 100644
--- a/src/marching_cubes.cu
+++ b/src/marching_cubes.cu
@@ -1048,7 +1048,7 @@ void save_rgba_grid_to_png_sequence(const GPUMemory<Array4f>& rgba, const char*
 	auto progress = tlog::progress(res3d.z());
 
 	std::atomic<int> n_saved{0};
-	ThreadPool{}.parallelFor<int>(0, res3d.z(), [&](int z) {
+	ThreadPool{}.parallel_for<int>(0, res3d.z(), [&](int z) {
 		uint8_t* pngpixels = (uint8_t*)malloc(size_t(w) * size_t(h) * 4);
 		uint8_t* dst = pngpixels;
 		for (int y = 0; y < h; ++y) {
diff --git a/src/nerf_loader.cu b/src/nerf_loader.cu
index 003c423fc3fc4f6f5c314f61fd316384aa8c2eda..dfb08b07dff20d4b1fef7b52c15925b8d718ceba 100644
--- a/src/nerf_loader.cu
+++ b/src/nerf_loader.cu
@@ -541,7 +541,7 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
 		}
 	}
 
-	if (json.contains("frames") && json["frames"].is_array()) pool.parallelForAsync<size_t>(0, json["frames"].size(), [&progress, &n_loaded, &result, &images, &json, basepath, image_idx, info, rolling_shutter, principal_point, lens, part_after_underscore, fix_premult, enable_depth_loading, enable_ray_loading](size_t i) {
+	if (json.contains("frames") && json["frames"].is_array()) pool.parallel_for_async<size_t>(0, json["frames"].size(), [&progress, &n_loaded, &result, &images, &json, basepath, image_idx, info, rolling_shutter, principal_point, lens, part_after_underscore, fix_premult, enable_depth_loading, enable_ray_loading](size_t i) {
 		size_t i_img = i + image_idx;
 		auto& frame = json["frames"][i];
 		LoadedImageInfo& dst = images[i_img];
@@ -706,7 +706,7 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
 		}
 	}
 
-	waitAll(futures);
+	wait_all(futures);
 
 	tlog::success() << "Loaded " << images.size() << " images after " << tlog::durationToString(progress.duration());
 	tlog::info() << "  cam_aabb=" << cam_aabb;
diff --git a/src/python_api.cu b/src/python_api.cu
index 8d4b27a2705572e9f0869643611c7be12fac47cd..ca4543c80d0fbb136f278814ab55e5e345ad7536 100644
--- a/src/python_api.cu
+++ b/src/python_api.cu
@@ -209,8 +209,7 @@ py::array_t<float> Testbed::screenshot(bool linear) const {
 	float* data = (float*)buf.ptr;
 
 	// Linear, alpha premultiplied, Y flipped
-	ThreadPool pool;
-	pool.parallelFor<size_t>(0, m_window_res.y(), [&](size_t y) {
+	ThreadPool{}.parallel_for<size_t>(0, m_window_res.y(), [&](size_t y) {
 		size_t base = y * m_window_res.x();
 		size_t base_reverse = (m_window_res.y() - y - 1) * m_window_res.x();
 		for (uint32_t x = 0; x < m_window_res.x(); ++x) {
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
index 0289db40e1d14c2144cda61f3dcfa227ba1e0158..ec8bcfaa04cedaf9100d231fcd71e7aeb66be7c0 100644
--- a/src/thread_pool.cpp
+++ b/src/thread_pool.cpp
@@ -23,36 +23,36 @@ using namespace std;
 
 ThreadPool::ThreadPool() : ThreadPool{thread::hardware_concurrency()} {}
 
-ThreadPool::ThreadPool(size_t maxNumThreads, bool force) {
+ThreadPool::ThreadPool(size_t max_num_threads, bool force) {
 	if (!force) {
-		maxNumThreads = min((size_t)thread::hardware_concurrency(), maxNumThreads);
+		max_num_threads = min((size_t)thread::hardware_concurrency(), max_num_threads);
 	}
-	startThreads(maxNumThreads);
+	start_threads(max_num_threads);
 }
 
 ThreadPool::~ThreadPool() {
-	shutdownThreads(mThreads.size());
+	shutdown_threads(m_threads.size());
 }
 
-void ThreadPool::startThreads(size_t num) {
-	mNumThreads += num;
-	for (size_t i = mThreads.size(); i < mNumThreads; ++i) {
-		mThreads.emplace_back([this, i] {
+void ThreadPool::start_threads(size_t num) {
+	m_num_threads += num;
+	for (size_t i = m_threads.size(); i < m_num_threads; ++i) {
+		m_threads.emplace_back([this, i] {
 			while (true) {
-				unique_lock<mutex> lock{mTaskQueueMutex};
+				unique_lock<mutex> lock{m_task_queue_mutex};
 
 				// look for a work item
-				while (i < mNumThreads && mTaskQueue.empty()) {
+				while (i < m_num_threads && m_task_queue.empty()) {
 					// if there are none wait for notification
-					mWorkerCondition.wait(lock);
+					m_worker_condition.wait(lock);
 				}
 
-				if (i >= mNumThreads) {
+				if (i >= m_num_threads) {
 					break;
 				}
 
-				function<void()> task{move(mTaskQueue.front())};
-				mTaskQueue.pop_front();
+				function<void()> task{move(m_task_queue.front())};
+				m_task_queue.pop_front();
 
 				// Unlock the lock, so we can process the task without blocking other threads
 				lock.unlock();
@@ -63,33 +63,33 @@ void ThreadPool::startThreads(size_t num) {
 	}
 }
 
-void ThreadPool::shutdownThreads(size_t num) {
-	auto numToClose = min(num, mNumThreads);
+void ThreadPool::shutdown_threads(size_t num) {
+	auto num_to_close = min(num, m_num_threads);
 	{
-		lock_guard<mutex> lock{mTaskQueueMutex};
-		mNumThreads -= numToClose;
+		lock_guard<mutex> lock{m_task_queue_mutex};
+		m_num_threads -= num_to_close;
 	}
 
 	// Wake up all the threads to have them quit
-	mWorkerCondition.notify_all();
-	for (auto i = 0u; i < numToClose; ++i) {
-		mThreads.back().join();
-		mThreads.pop_back();
+	m_worker_condition.notify_all();
+	for (auto i = 0u; i < num_to_close; ++i) {
+		m_threads.back().join();
+		m_threads.pop_back();
 	}
 }
 
-void ThreadPool::setNThreads(size_t num) {
-	if (mNumThreads > num) {
-		shutdownThreads(mNumThreads - num);
-	} else if (mNumThreads < num) {
-		startThreads(num - mNumThreads);
+void ThreadPool::set_n_threads(size_t num) {
+	if (m_num_threads > num) {
+		shutdown_threads(m_num_threads - num);
+	} else if (m_num_threads < num) {
+		start_threads(num - m_num_threads);
 	}
 }
 
-void ThreadPool::flushQueue() {
-	lock_guard<mutex> lock{mTaskQueueMutex};
-	mTaskQueue.clear();
+void ThreadPool::flush_queue() {
+	lock_guard<mutex> lock{m_task_queue_mutex};
+	m_task_queue.clear();
 }
 
 NGP_NAMESPACE_END
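---

Reviewer note: below is a minimal usage sketch of the renamed API, included only to sanity-check that the new snake_case names read correctly together; it is not part of the patch. It assumes NGP_NAMESPACE_BEGIN expands to `namespace ngp` (as elsewhere in the codebase) and that the header above is on the include path.

    // Illustrative only -- not part of the patch.
    #include <neural-graphics-primitives/thread_pool.h>

    #include <atomic>
    #include <cstdio>

    int main() {
        ngp::ThreadPool pool; // defaults to thread::hardware_concurrency() workers

        // Blocking form: parallel_for() splits [0, 1000) into one chunk per
        // worker and returns only after every chunk has run.
        std::atomic<int> sum{0};
        pool.parallel_for<int>(0, 1000, [&](int i) { sum += i; });
        std::printf("sum = %d\n", sum.load()); // prints 499500

        // Non-blocking form: parallel_for_async() hands back the futures so the
        // caller can overlap other work, then wait_all() (formerly waitAll)
        // blocks on them explicitly.
        auto futures = pool.parallel_for_async<int>(0, 100, [](int /*i*/) { /* per-index work */ });
        ngp::wait_all(futures);
    }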