From 324a19e6de43f8b51ea14de35a63ba89e8ce733c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thomas=20M=C3=BCller?= <tmueller@nvidia.com>
Date: Tue, 7 Feb 2023 09:56:12 +0100
Subject: [PATCH] ThreadPool: wait until all tasks are completed upon
 destruction

---
 include/neural-graphics-primitives/thread_pool.h |  2 ++
 src/thread_pool.cpp                              | 10 +++++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/include/neural-graphics-primitives/thread_pool.h b/include/neural-graphics-primitives/thread_pool.h
index c991f22..8798883 100644
--- a/include/neural-graphics-primitives/thread_pool.h
+++ b/include/neural-graphics-primitives/thread_pool.h
@@ -63,6 +63,7 @@ public:
 	void shutdown_threads(size_t num);
 	void set_n_threads(size_t num);
 
+	void wait_until_queue_completed();
 	void flush_queue();
 
 	template <typename Int, typename F>
@@ -102,6 +103,7 @@ private:
 	std::deque<std::function<void()>> m_task_queue;
 	std::mutex m_task_queue_mutex;
 	std::condition_variable m_worker_condition;
+	std::condition_variable m_task_queue_completed_condition;
 };
 
 NGP_NAMESPACE_END
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
index ec8bcfa..4acf4ce 100644
--- a/src/thread_pool.cpp
+++ b/src/thread_pool.cpp
@@ -31,6 +31,7 @@ ThreadPool::ThreadPool(size_t max_num_threads, bool force) {
 }
 
 ThreadPool::~ThreadPool() {
+	wait_until_queue_completed();
 	shutdown_threads(m_threads.size());
 }
 
@@ -43,7 +44,9 @@ void ThreadPool::start_threads(size_t num) {
 
 				// look for a work item
 				while (i < m_num_threads && m_task_queue.empty()) {
-					// if there are none wait for notification
+					// if there are none, signal that the queue is completed
+					// and wait for notification of new work items.
+					m_task_queue_completed_condition.notify_all();
 					m_worker_condition.wait(lock);
 				}
 
@@ -87,6 +90,11 @@ void ThreadPool::set_n_threads(size_t num) {
 	}
 }
 
+void ThreadPool::wait_until_queue_completed() {
+	unique_lock<mutex> lock{m_task_queue_mutex};
+	m_task_queue_completed_condition.wait(lock, [this]() { return m_task_queue.empty(); });
+}
+
 void ThreadPool::flush_queue() {
 	lock_guard<mutex> lock{m_task_queue_mutex};
 	m_task_queue.clear();
-- 
GitLab