|  | 
|  | 1 | +#ifndef THREAD_POOL_H | 
|  | 2 | +#define THREAD_POOL_H | 
|  | 3 | + | 
|  | 4 | +#include <vector> | 
|  | 5 | +#include <deque> | 
|  | 6 | +#include <memory> | 
|  | 7 | +#include <thread> | 
|  | 8 | +#include <mutex> | 
|  | 9 | +#include <condition_variable> | 
|  | 10 | + | 
|  | 11 | +class ThreadPool; | 
|  | 12 | +  | 
|  | 13 | +// our worker thread objects | 
|  | 14 | +class Worker { | 
|  | 15 | +public: | 
|  | 16 | +    Worker(ThreadPool &s) : pool(s) { } | 
|  | 17 | +    void operator()(); | 
|  | 18 | +private: | 
|  | 19 | +    ThreadPool &pool; | 
|  | 20 | +}; | 
|  | 21 | + | 
|  | 22 | +template<class T> | 
|  | 23 | +class Result { | 
|  | 24 | +    struct ResultImpl { | 
|  | 25 | +        ResultImpl() : value(T()), available(false) { } | 
|  | 26 | +        T value; | 
|  | 27 | +        bool available; | 
|  | 28 | +        std::mutex lock; | 
|  | 29 | +        std::condition_variable cond; | 
|  | 30 | +    }; | 
|  | 31 | +public: | 
|  | 32 | +    Result() : impl(new ResultImpl()) { } | 
|  | 33 | +    bool available() const | 
|  | 34 | +    { | 
|  | 35 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 36 | +        return impl->available; | 
|  | 37 | +    } | 
|  | 38 | +    void wait() | 
|  | 39 | +    { | 
|  | 40 | +        if(!impl) | 
|  | 41 | +            return; | 
|  | 42 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 43 | +        if(impl->available) | 
|  | 44 | +            return; | 
|  | 45 | +        impl->cond.wait(ul); | 
|  | 46 | +    } | 
|  | 47 | +    void signal() const | 
|  | 48 | +    { | 
|  | 49 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 50 | +        impl->available = true; impl->cond.notify_all(); | 
|  | 51 | +    } | 
|  | 52 | +    bool valid() const | 
|  | 53 | +    {  | 
|  | 54 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 55 | +        return static_cast<bool>(impl); | 
|  | 56 | +    } | 
|  | 57 | + | 
|  | 58 | +    T& get() | 
|  | 59 | +    {  | 
|  | 60 | +        wait();  | 
|  | 61 | +        return impl->value; | 
|  | 62 | +    } | 
|  | 63 | +    void set(T v) const | 
|  | 64 | +    { | 
|  | 65 | +        std::unique_lock<std::mutex> ul(impl->lock);  | 
|  | 66 | +        impl->value = v;  | 
|  | 67 | +    } | 
|  | 68 | +   | 
|  | 69 | +private: | 
|  | 70 | +    std::shared_ptr<ResultImpl> impl; | 
|  | 71 | +}; | 
|  | 72 | + | 
|  | 73 | +template<> | 
|  | 74 | +class Result<void> { | 
|  | 75 | +    struct ResultImpl { | 
|  | 76 | +        ResultImpl() : available(false) {  } | 
|  | 77 | +        bool available; | 
|  | 78 | +        std::mutex lock; | 
|  | 79 | +        std::condition_variable cond; | 
|  | 80 | +    }; | 
|  | 81 | +public: | 
|  | 82 | +    Result() : impl(new ResultImpl()) { } | 
|  | 83 | + | 
|  | 84 | +    bool available() const | 
|  | 85 | +    { | 
|  | 86 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 87 | +        return impl->available; | 
|  | 88 | +    } | 
|  | 89 | +    void wait() | 
|  | 90 | +    { | 
|  | 91 | +        if(!impl) | 
|  | 92 | +            return; | 
|  | 93 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 94 | +        if(impl->available) | 
|  | 95 | +            return; | 
|  | 96 | +        impl->cond.wait(ul); | 
|  | 97 | +    } | 
|  | 98 | +    void signal() const | 
|  | 99 | +    { | 
|  | 100 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 101 | +        impl->available = true; impl->cond.notify_all(); | 
|  | 102 | +    } | 
|  | 103 | +    bool valid() const | 
|  | 104 | +    {  | 
|  | 105 | +        std::unique_lock<std::mutex> ul(impl->lock); | 
|  | 106 | +        return static_cast<bool>(impl); | 
|  | 107 | +    } | 
|  | 108 | +private: | 
|  | 109 | +    std::shared_ptr<ResultImpl> impl; | 
|  | 110 | +}; | 
|  | 111 | + | 
|  | 112 | +// the actual thread pool | 
|  | 113 | +class ThreadPool { | 
|  | 114 | +public: | 
|  | 115 | +    ThreadPool(size_t); | 
|  | 116 | +    template<class T, class F> | 
|  | 117 | +    Result<T> enqueue(F f); | 
|  | 118 | +    ~ThreadPool(); | 
|  | 119 | +private: | 
|  | 120 | +    friend class Worker; | 
|  | 121 | + | 
|  | 122 | +    // need to keep track of threads so we can join them | 
|  | 123 | +    std::vector< std::thread > workers; | 
|  | 124 | +    // the task queue | 
|  | 125 | +    std::deque< std::function<void()> > tasks; | 
|  | 126 | +     | 
|  | 127 | +    // synchronization | 
|  | 128 | +    std::mutex queue_mutex; | 
|  | 129 | +    std::condition_variable condition; | 
|  | 130 | +    bool stop; | 
|  | 131 | +}; | 
|  | 132 | +  | 
|  | 133 | +void Worker::operator()() | 
|  | 134 | +{ | 
|  | 135 | +    std::function<void()> task; | 
|  | 136 | +    while(true) | 
|  | 137 | +    { | 
|  | 138 | +        { | 
|  | 139 | +            std::unique_lock<std::mutex> lock(pool.queue_mutex); | 
|  | 140 | +            while(!pool.stop && pool.tasks.empty()) | 
|  | 141 | +                pool.condition.wait(lock); | 
|  | 142 | +            if(pool.stop) | 
|  | 143 | +                return; | 
|  | 144 | +            task = pool.tasks.front(); | 
|  | 145 | +            pool.tasks.pop_front(); | 
|  | 146 | +        } | 
|  | 147 | +        task(); | 
|  | 148 | +    } | 
|  | 149 | +} | 
|  | 150 | +  | 
|  | 151 | +// the constructor just launches some amount of workers | 
|  | 152 | +ThreadPool::ThreadPool(size_t threads) | 
|  | 153 | +    :   stop(false) | 
|  | 154 | +{ | 
|  | 155 | +    for(size_t i = 0;i<threads;++i) | 
|  | 156 | +        workers.push_back(std::thread(Worker(*this))); | 
|  | 157 | +} | 
|  | 158 | + | 
|  | 159 | +template<class T, class F> | 
|  | 160 | +struct CallAndSet { | 
|  | 161 | +    void operator()(const Result<T> &res, const F f) | 
|  | 162 | +    { | 
|  | 163 | +        res.set(f()); | 
|  | 164 | +        res.signal(); | 
|  | 165 | +    } | 
|  | 166 | +}; | 
|  | 167 | + | 
|  | 168 | +template<class F> | 
|  | 169 | +struct CallAndSet<void,F> { | 
|  | 170 | +    void operator()(const Result<void> &res, const F &f) | 
|  | 171 | +    { | 
|  | 172 | +        f(); | 
|  | 173 | +        res.signal(); | 
|  | 174 | +    } | 
|  | 175 | +}; | 
|  | 176 | +  | 
|  | 177 | +// add new work item to the pool | 
|  | 178 | +template<class T, class F> | 
|  | 179 | +Result<T> ThreadPool::enqueue(F f) | 
|  | 180 | +{ | 
|  | 181 | +    Result<T> res; | 
|  | 182 | +    { | 
|  | 183 | +        std::unique_lock<std::mutex> lock(queue_mutex); | 
|  | 184 | +        tasks.push_back(std::function<void()>( | 
|  | 185 | +        [f,res]() | 
|  | 186 | +        { | 
|  | 187 | +            CallAndSet<T,F>()(res, f); | 
|  | 188 | +        })); | 
|  | 189 | +    } | 
|  | 190 | +    condition.notify_one(); | 
|  | 191 | +    return res; | 
|  | 192 | +} | 
|  | 193 | +  | 
|  | 194 | +// the destructor joins all threads | 
|  | 195 | +ThreadPool::~ThreadPool() | 
|  | 196 | +{ | 
|  | 197 | +    stop = true; | 
|  | 198 | +    condition.notify_all(); | 
|  | 199 | +    for(size_t i = 0;i<workers.size();++i) | 
|  | 200 | +        workers[i].join(); | 
|  | 201 | +} | 
|  | 202 | + | 
|  | 203 | +#endif | 
0 commit comments