From a2ba0e63f74ca41f43ffcc4cad827aa2318da211 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thomas=20M=C3=BCller?= <tmueller@nvidia.com>
Date: Wed, 14 Dec 2022 09:36:30 +0100
Subject: [PATCH] `ThreadPool` simplifications

---
 .../neural-graphics-primitives/thread_pool.h  | 13 +------
 src/thread_pool.cpp                           | 35 +++----------------
 2 files changed, 6 insertions(+), 42 deletions(-)

diff --git a/include/neural-graphics-primitives/thread_pool.h b/include/neural-graphics-primitives/thread_pool.h
index 47e9d7b..37f7fd5 100644
--- a/include/neural-graphics-primitives/thread_pool.h
+++ b/include/neural-graphics-primitives/thread_pool.h
@@ -41,8 +41,6 @@ public:
     auto enqueueTask(F&& f, bool highPriority = false) -> std::future<std::result_of_t <F()>> {
         using return_type = std::result_of_t<F()>;
 
-        ++mNumTasksInSystem;
-
         auto task = std::make_shared<std::packaged_task<return_type()>>(std::forward<F>(f));
 
         auto res = task->get_future();
@@ -63,13 +61,8 @@ public:
 
     void startThreads(size_t num);
     void shutdownThreads(size_t num);
+	void setNThreads(size_t num);
 
-    size_t numTasksInSystem() const {
-        return mNumTasksInSystem;
-    }
-
-    void waitUntilFinished();
-    void waitUntilFinishedFor(const std::chrono::microseconds Duration);
     void flushQueue();
 
     template <typename Int, typename F>
@@ -109,10 +102,6 @@ private:
     std::deque<std::function<void()>> mTaskQueue;
     std::mutex mTaskQueueMutex;
     std::condition_variable mWorkerCondition;
-
-    std::atomic<size_t> mNumTasksInSystem;
-    std::mutex mSystemBusyMutex;
-    std::condition_variable mSystemBusyCondition;
 };
 
 NGP_NAMESPACE_END
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
index 17551b1..0289db4 100644
--- a/src/thread_pool.cpp
+++ b/src/thread_pool.cpp
@@ -28,7 +28,6 @@ ThreadPool::ThreadPool(size_t maxNumThreads, bool force) {
 		maxNumThreads = min((size_t)thread::hardware_concurrency(), maxNumThreads);
 	}
 	startThreads(maxNumThreads);
-	mNumTasksInSystem.store(0);
 }
 
 ThreadPool::~ThreadPool() {
@@ -59,16 +58,6 @@ void ThreadPool::startThreads(size_t num) {
 				lock.unlock();
 
 				task();
-
-				mNumTasksInSystem--;
-
-				{
-					unique_lock<mutex> localLock{mSystemBusyMutex};
-
-					if (mNumTasksInSystem == 0) {
-						mSystemBusyCondition.notify_all();
-					}
-				}
 			}
 		});
 	}
@@ -90,30 +79,16 @@ void ThreadPool::shutdownThreads(size_t num) {
 	}
 }
 
-void ThreadPool::waitUntilFinished() {
-	unique_lock<mutex> lock{mSystemBusyMutex};
-
-	if (mNumTasksInSystem == 0) {
-		return;
-	}
-
-	mSystemBusyCondition.wait(lock);
-}
-
-void ThreadPool::waitUntilFinishedFor(const chrono::microseconds Duration) {
-	unique_lock<mutex> lock{mSystemBusyMutex};
-
-	if (mNumTasksInSystem == 0) {
-		return;
+void ThreadPool::setNThreads(size_t num) {
+	if (mNumThreads > num) {
+		shutdownThreads(mNumThreads - num);
+	} else if (mNumThreads < num) {
+		startThreads(num - mNumThreads);
 	}
-
-	mSystemBusyCondition.wait_for(lock, Duration);
 }
 
 void ThreadPool::flushQueue() {
 	lock_guard<mutex> lock{mTaskQueueMutex};
-
-	mNumTasksInSystem -= mTaskQueue.size();
 	mTaskQueue.clear();
 }
 
-- 
GitLab