diff --git a/ThreadPool.h b/ThreadPool.h deleted file mode 100644 index 0235c8b..0000000 --- a/ThreadPool.h +++ /dev/null @@ -1,93 +0,0 @@ -#ifndef THREAD_POOL_H -#define THREAD_POOL_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { -public: - ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); -private: - // need to keep track of threads so we can join them - std::vector< std::thread > workers; - // the task queue - std::queue< std::function > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) - : stop(false) -{ - for(size_t i = 0;i lock(this->queue_mutex); - while(!this->stop && this->tasks.empty()) - this->condition.wait(lock); - if(this->stop && this->tasks.empty()) - return; - std::function task(this->tasks.front()); - this->tasks.pop(); - lock.unlock(); - task(); - } - } - ); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> -{ - typedef typename std::result_of::type return_type; - - // don't allow enqueueing after stopping the pool - if(stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); - - auto task = std::make_shared< std::packaged_task >( - std::bind(std::forward(f), std::forward(args)...) - ); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - tasks.push([task](){ (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() -{ - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for(size_t i = 0;i +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t = std::thread::hardware_concurrency()); + ~ThreadPool(); + + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + + // wait for the completion of all tasks + // (meaning that all worker are waiting for new tasks) + void wait() const; + +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; + + // waiting for completion + size_t active_worker; + mutable std::mutex work_done_mutex; + mutable std::condition_variable work_done_condition; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false), + active_worker(threads) +{ + for (size_t i = 0; i < threads; ++i) { + workers.emplace_back( + [this] { + for (;;) + { + std::unique_lock lock(this->queue_mutex); + + --this->active_worker; + + while (!this->stop && this->tasks.empty()) { + this->work_done_condition.notify_one(); // signal that this thread is done + this->condition.wait(lock); // and wait for more tasks + } + + if (this->stop && this->tasks.empty()) + return; + + ++this->active_worker; + + std::function task(this->tasks.front()); + this->tasks.pop(); + lock.unlock(); + + task(); + } + } + ); + } +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) +-> std::future::type> +{ + using return_type = typename std::result_of::type; + + // don't allow enqueueing after stopping the pool + if (stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + tasks.push([task](){ (*task)(); }); + } + condition.notify_one(); + return res; +} + +inline void ThreadPool::wait() const { + std::unique_lock lock(work_done_mutex); + + // wait until all threads are done and tasks are empty + while (!(active_worker == 0 && tasks.empty())) + work_done_condition.wait(lock); +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(this->queue_mutex); + stop = true; + } + + condition.notify_all(); + for (size_t i = 0; i < workers.size(); ++i) + workers[i].join(); +} + +#endif diff --git a/example.cpp b/example.cpp index 66d6ab7..cca960a 100644 --- a/example.cpp +++ b/example.cpp @@ -2,15 +2,14 @@ #include #include -#include "ThreadPool.h" +#include "ThreadPool.hpp" int main() { - ThreadPool pool(4); - std::vector< std::future > results; + std::vector> results; - for(int i = 0; i < 8; ++i) { + for (int i = 0; i < 8; ++i) { results.push_back( pool.enqueue([i] { std::cout << "hello " << i << std::endl; @@ -19,11 +18,14 @@ int main() return i*i; }) ); - } - - for(size_t i = 0;i