diff --git a/include/neural-graphics-primitives/thread_pool.h b/include/neural-graphics-primitives/thread_pool.h index c991f22aee2c54fd012993e85f1cf75e9e33092b..879888306778a6601d1be8ea3a48578a7ab29b5c 100644 --- a/include/neural-graphics-primitives/thread_pool.h +++ b/include/neural-graphics-primitives/thread_pool.h @@ -63,6 +63,7 @@ public: void shutdown_threads(size_t num); void set_n_threads(size_t num); + void wait_until_queue_completed(); void flush_queue(); template <typename Int, typename F> @@ -102,6 +103,7 @@ private: std::deque<std::function<void()>> m_task_queue; std::mutex m_task_queue_mutex; std::condition_variable m_worker_condition; + std::condition_variable m_task_queue_completed_condition; }; NGP_NAMESPACE_END diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp index ec8bcfaa04cedaf9100d231fcd71e7aeb66be7c0..4acf4ce83d06b92dded0129bf3b985ad7736ca0f 100644 --- a/src/thread_pool.cpp +++ b/src/thread_pool.cpp @@ -31,6 +31,7 @@ ThreadPool::ThreadPool(size_t max_num_threads, bool force) { } ThreadPool::~ThreadPool() { + wait_until_queue_completed(); shutdown_threads(m_threads.size()); } @@ -43,7 +44,9 @@ void ThreadPool::start_threads(size_t num) { // look for a work item while (i < m_num_threads && m_task_queue.empty()) { - // if there are none wait for notification + // if there are none, signal that the queue is completed + // and wait for notification of new work items. + m_task_queue_completed_condition.notify_all(); m_worker_condition.wait(lock); } @@ -87,6 +90,11 @@ void ThreadPool::set_n_threads(size_t num) { } } +void ThreadPool::wait_until_queue_completed() { + unique_lock<mutex> lock{m_task_queue_mutex}; + m_task_queue_completed_condition.wait(lock, [this]() { return m_task_queue.empty(); }); +} + void ThreadPool::flush_queue() { lock_guard<mutex> lock{m_task_queue_mutex}; m_task_queue.clear();