// PAL/NpuYoloV5/06_rknn-cpp-Multithreading-main/include/ThreadPool.hpp
// 2025-04-28 14:48:28 +08:00
//
// 168 lines
// 4.7 KiB
// C++

#ifndef THREADPOOL_H
#define THREADPOOL_H
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <utility>
namespace dpool
{
/// @brief Grow-on-demand thread pool.
///
/// Worker threads are created lazily by submit(), never more than the configured
/// maximum. A worker that receives no task within the idle timeout leaves its loop.
/// Such exited workers are joined by the next worker that times out, or by the
/// destructor.
///
/// submit() and threadsNum() lock an internal mutex. Tasks run outside that mutex, so a
/// task may itself call submit(). submit() must not be called once destruction has begun.
class ThreadPool
{
    // std::result_of is deprecated in C++17 and removed in C++20. Use std::invoke_result
    // when the standard library provides it. Both accept the same callables, including
    // pointers to members.
#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L
    template <typename Func, typename... Ts>
    using ResultOf = typename std::invoke_result<Func, Ts...>::type;
#else
    template <typename Func, typename... Ts>
    using ResultOf = typename std::result_of<Func(Ts...)>::type;
#endif

public:
    using MutexGuard = std::lock_guard<std::mutex>;
    using UniqueLock = std::unique_lock<std::mutex>;
    using Thread = std::thread;
    using ThreadID = std::thread::id;
    using Task = std::function<void()>;

    /// Creates a pool capped at std::thread::hardware_concurrency() workers.
    /// That call may return 0; the explicit constructor turns 0 into 1.
    ThreadPool()
        : ThreadPool(Thread::hardware_concurrency())
    {
    }

    /// @param maxThreads  Upper bound on simultaneously running workers. 0 is treated as 1,
    ///                    because a pool with no workers would never run any task.
    /// @param idleTimeout How long a worker waits for a task before leaving its loop
    ///                    (default 2 s). Storing it as a member avoids the out-of-class
    ///                    static definition a header cannot safely hold before C++17.
    explicit ThreadPool(size_t maxThreads,
                        std::chrono::milliseconds idleTimeout = std::chrono::seconds(2))
        : quit_(false),
          currentThreads_(0),
          idleThreads_(0),
          maxThreads_(maxThreads > 0 ? maxThreads : 1),
          idleTimeout_(idleTimeout)
    {
    }

    // Disable the copy operations: workers hold a pointer to this pool.
    ThreadPool(const ThreadPool &) = delete;
    ThreadPool &operator=(const ThreadPool &) = delete;

    /// Sets the quit flag and wakes every worker. Workers keep running queued tasks until
    /// the queue is empty. The destructor then joins every worker thread, including those
    /// that already left their loop on idle timeout.
    ~ThreadPool()
    {
        {
            MutexGuard guard(mutex_);
            quit_ = true;
        }
        cv_.notify_all();
        // Iterating threads_ without the lock is safe. Once quit_ is set, no worker enters
        // the idle-timeout branch, which is the only code that touches threads_ from a worker.
        // submit() must not run concurrently with destruction.
        for (auto &elem : threads_)
        {
            assert(elem.second.joinable());
            elem.second.join();
        }
    }

    /// @brief Queues `func(params...)` for execution on a worker thread.
    ///
    /// The arguments are bound with std::bind: they are decay-copied into the task and
    /// passed to `func` as lvalues.
    /// @return A future that receives the call's result, or the exception it threw.
    template <typename Func, typename... Ts>
    auto submit(Func &&func, Ts &&...params)
        -> std::future<ResultOf<Func, Ts...>>
    {
        using ReturnType = ResultOf<Func, Ts...>;
        using PackagedTask = std::packaged_task<ReturnType()>;
        auto execute = std::bind(std::forward<Func>(func), std::forward<Ts>(params)...);
        // std::function needs a copyable target and packaged_task is move-only,
        // so the task is held through a shared_ptr.
        auto task = std::make_shared<PackagedTask>(std::move(execute));
        auto result = task->get_future();

        MutexGuard guard(mutex_);
        assert(!quit_);
        tasks_.emplace([task]()
                       { (*task)(); });
        if (idleThreads_ > 0)
        {
            cv_.notify_one();
        }
        // idleThreads_ only drops once a woken worker reacquires the mutex, so
        // back-to-back submits can all count the same idle worker. Spawn whenever the
        // queued tasks outnumber the workers waiting to take them. Otherwise a task that
        // waits on a later task can deadlock while thread capacity is still available.
        if (tasks_.size() > idleThreads_ && currentThreads_ < maxThreads_)
        {
            spawnWorker();
        }
        return result;
    }

    /// Number of workers that have not yet left their work loop.
    size_t threadsNum() const
    {
        MutexGuard guard(mutex_);
        return currentThreads_;
    }

private:
    // Starts one worker thread and records it. The caller must hold mutex_.
    void spawnWorker()
    {
        Thread t(&ThreadPool::worker, this);
        const ThreadID id = t.get_id();
        assert(threads_.find(id) == threads_.end());
        threads_.emplace(id, std::move(t));
        ++currentThreads_;
    }

    // Worker loop. Takes tasks until the pool quits with an empty queue, or until no task
    // arrives within idleTimeout_.
    void worker()
    {
        while (true)
        {
            Task task;
            {
                UniqueLock uniqueLock(mutex_);
                ++idleThreads_;
                const bool hasTimedout = !cv_.wait_for(uniqueLock,
                                                       idleTimeout_,
                                                       [this]()
                                                       {
                                                           return quit_ || !tasks_.empty();
                                                       });
                --idleThreads_;
                if (tasks_.empty())
                {
                    if (quit_)
                    {
                        --currentThreads_;
                        return;
                    }
                    if (hasTimedout)
                    {
                        // A thread cannot join itself. Reap workers that exited earlier,
                        // then leave our own id for the next timed-out worker or the destructor.
                        --currentThreads_;
                        joinFinishedThreads();
                        finishedThreadIDs_.emplace(std::this_thread::get_id());
                        return;
                    }
                }
                task = std::move(tasks_.front());
                tasks_.pop();
            }
            // Run outside the lock so submit() and other workers are not blocked.
            task();
        }
    }

    // Joins and forgets every worker recorded in finishedThreadIDs_. The caller must hold
    // mutex_. Joining under the lock is safe because a finishing worker records its id
    // while holding the mutex, then returns without locking it again.
    void joinFinishedThreads()
    {
        while (!finishedThreadIDs_.empty())
        {
            auto id = std::move(finishedThreadIDs_.front());
            finishedThreadIDs_.pop();
            auto iter = threads_.find(id);
            assert(iter != threads_.end());
            assert(iter->second.joinable());
            iter->second.join();
            threads_.erase(iter);
        }
    }

    bool quit_;                                     // set once by the destructor
    size_t currentThreads_;                         // workers still inside their loop
    size_t idleThreads_;                            // workers blocked in wait_for
    size_t maxThreads_;                             // always >= 1
    std::chrono::milliseconds idleTimeout_;         // idle time before a worker exits
    mutable std::mutex mutex_;                      // guards every member above and below
    std::condition_variable cv_;
    std::queue<Task> tasks_;
    std::queue<ThreadID> finishedThreadIDs_;        // exited but not yet joined
    std::unordered_map<ThreadID, Thread> threads_;
};
} // namespace dpool
#endif /* THREADPOOL_H */