Skip to content

Commit 8fa1f0c

Browse files
author
Joshua Bradt
committed
Refactored thread pools with base classes.
- Created ThreadPoolBase class - Inherited from it for LockfreeThreadPool and ThreadPool - Removed ParallelMap since we weren't using it
1 parent 3d54895 commit 8fa1f0c

File tree

6 files changed

+115
-152
lines changed

6 files changed

+115
-152
lines changed

include/LockfreeThreadPool.hh

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,69 +9,44 @@
99
#include <utility>
1010
#include <future>
1111

12+
#include "ThreadPoolBase.hh"
1213
#include "concurrentqueue.h"
1314

1415
using namespace std;
1516
using namespace moodycamel;
1617

1718
typedef std::function<void()> WorkType;
1819

19-
class LockfreeThreadPool {
20+
class LockfreeThreadPool : public ThreadPoolBase {
2021
public:
21-
LockfreeThreadPool(uint32_t numthreads);
22-
~LockfreeThreadPool();
2322

24-
void Enqueue(function<void()> task);
25-
void Worker();
26-
void JoinAll();
27-
void Finish();
23+
LockfreeThreadPool(uint32_t num_threads)
24+
{
25+
for (uint32_t i=0; i<num_threads;i++) {
26+
m_workers.emplace_back(&LockfreeThreadPool::Worker, this);
27+
}
28+
}
2829

29-
template <typename T, typename... Params>
30-
void ParallelFor(uint32_t begin, uint32_t end, int32_t n_tasks, T SerialFunction, Params&&... params) {
30+
~LockfreeThreadPool() {
31+
CleanUp();
32+
JoinAll();
33+
}
3134

32-
n_tasks = (n_tasks >= m_nthreads) ? n_tasks : m_nthreads;
33-
int chunk = (end - begin) / n_tasks;
34-
for (int i = 0; i < n_tasks; ++i) {
35-
m_promises.emplace_back();
36-
int mypromise = m_promises.size() - 1;
37-
m_taskQueue.enqueue([=]{
38-
uint32_t threadstart = begin + i*chunk;
39-
uint32_t threadstop = (i == n_tasks - 1) ? end : threadstart + chunk;
40-
for (uint32_t it = threadstart; it < threadstop; ++it) {
41-
SerialFunction(it, params...);
42-
}
43-
m_promises[mypromise].set_value();
44-
});
45-
}
46-
Finish();
35+
void AddTask(WorkType task) override
36+
{
37+
m_taskQueue.enqueue(task);
4738
}
48-
template<typename InputIt, typename T>
49-
void ParallelMap(InputIt begin, InputIt end, InputIt outputBegin, const std::function<T(T)>& func)
39+
40+
void Worker() override
5041
{
51-
int chunkSize = (end - begin) / m_nthreads;
52-
for (int i = 0; i < m_nthreads; i++) {
53-
m_promises.emplace_back();
54-
int mypromise = m_promises.size() - 1;
55-
m_taskQueue.enqueue([=]{
56-
InputIt threadBegin = begin + i*chunkSize;
57-
InputIt threadOutput = outputBegin + i*chunkSize;
58-
InputIt threadEnd = (i == m_nthreads - 1) ? end : threadBegin + chunkSize;
59-
while (threadBegin != threadEnd) {
60-
*(threadOutput++) = func(*(threadBegin++));
61-
}
62-
m_promises[mypromise].set_value();
63-
});
42+
while(true) {
43+
function<void()> work;
44+
if (m_taskQueue.try_dequeue(work)) work();
45+
if (m_stopWorkers) return;
6446
}
65-
Finish();
6647
}
6748

6849
private:
69-
// threads and task queue
70-
int m_nthreads;
71-
vector<thread> m_workers;
72-
vector<promise<void>> m_promises;
73-
bool m_stopWorkers;
74-
7550
ConcurrentQueue<WorkType> m_taskQueue;
7651
};
7752

include/ThreadPool.hh

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,68 +9,55 @@
99
#include <utility>
1010
#include <future>
1111

12+
#include "ThreadPoolBase.hh"
13+
1214
#include "ThreadsafeQueue.hh"
1315

1416
using namespace std;
1517

1618
typedef std::function<void()> WorkType;
1719

18-
class ThreadPool {
20+
class ThreadPool : public ThreadPoolBase {
1921
public:
20-
ThreadPool(uint32_t numthreads);
21-
~ThreadPool();
2222

23-
void Enqueue(function<void()> task);
24-
void Worker();
25-
void JoinAll();
26-
void Finish();
23+
ThreadPool(uint32_t num_threads)
24+
{
25+
for (uint32_t i=0; i<num_threads;i++) {
26+
m_workers.emplace_back(&ThreadPool::Worker, this);
27+
}
28+
}
2729

28-
template <typename T, typename... Params>
29-
void ParallelFor(uint32_t begin, uint32_t end, int32_t n_tasks, T SerialFunction, Params&&... params) {
30+
~ThreadPool() {
31+
CleanUp();
32+
JoinAll();
33+
}
3034

31-
n_tasks = (n_tasks >= m_nthreads) ? n_tasks : m_nthreads;
32-
int chunk = (end - begin) / n_tasks;
33-
for (int i = 0; i < n_tasks; ++i) {
34-
m_promises.emplace_back();
35-
int mypromise = m_promises.size() - 1;
36-
m_taskQueue.push([=]{
37-
uint32_t threadstart = begin + i*chunk;
38-
uint32_t threadstop = (i == n_tasks - 1) ? end : threadstart + chunk;
39-
for (uint32_t it = threadstart; it < threadstop; ++it) {
40-
SerialFunction(it, params...);
41-
}
42-
m_promises[mypromise].set_value();
43-
});
44-
}
45-
Finish();
35+
void AddTask(WorkType task) override
36+
{
37+
m_taskQueue.push(task);
4638
}
47-
template<typename InputIt, typename T>
48-
void ParallelMap(InputIt begin, InputIt end, InputIt outputBegin, const std::function<T(T)>& func)
39+
40+
void CleanUp() override
4941
{
50-
int chunkSize = (end - begin) / m_nthreads;
51-
for (int i = 0; i < m_nthreads; i++) {
52-
m_promises.emplace_back();
53-
int mypromise = m_promises.size() - 1;
54-
m_taskQueue.push([=]{
55-
InputIt threadBegin = begin + i*chunkSize;
56-
InputIt threadOutput = outputBegin + i*chunkSize;
57-
InputIt threadEnd = (i == m_nthreads - 1) ? end : threadBegin + chunkSize;
58-
while (threadBegin != threadEnd) {
59-
*(threadOutput++) = func(*(threadBegin++));
60-
}
61-
m_promises[mypromise].set_value();
62-
});
42+
m_stopWorkers = true;
43+
m_taskQueue.join();
44+
}
45+
46+
void Worker() override {
47+
while(true) {
48+
WorkType work;
49+
try {
50+
work = m_taskQueue.pop();
51+
work();
52+
}
53+
catch (const ThreadsafeQueue<WorkType>::QueueFinished&)
54+
{
55+
return;
56+
}
6357
}
64-
Finish();
6558
}
6659

6760
private:
68-
// threads and task queue
69-
int m_nthreads;
70-
vector<thread> m_workers;
71-
vector<promise<void>> m_promises;
72-
bool m_stopWorkers;
73-
7461
ThreadsafeQueue<WorkType> m_taskQueue;
7562
};
7663

include/ThreadPoolBase.hh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#ifndef THREADPOOLBASE_H
2+
#define THREADPOOLBASE_H
3+
4+
#include <iostream>
5+
#include <chrono>
6+
#include <thread>
7+
#include <functional>
8+
#include <vector>
9+
#include <utility>
10+
#include <future>
11+
12+
using namespace std;
13+
14+
typedef std::function<void()> WorkType;
15+
16+
class ThreadPoolBase {
17+
public:
18+
ThreadPoolBase() = default;
19+
20+
virtual void AddTask(function<void()> task) = 0;
21+
22+
virtual void Worker() = 0;
23+
virtual void CleanUp() { m_stopWorkers = true; }
24+
void JoinAll();
25+
void Finish();
26+
27+
template <typename T, typename... Params>
28+
void ParallelFor(uint32_t begin, uint32_t end, uint32_t n_tasks, T SerialFunction, Params&&... params) {
29+
30+
n_tasks = (n_tasks >= m_workers.size()) ? n_tasks : m_workers.size();
31+
m_tasksRemaining = n_tasks;
32+
int chunk = (end - begin) / n_tasks;
33+
for (int i = 0; i < n_tasks; ++i) {
34+
AddTask([=]{
35+
uint32_t threadstart = begin + i*chunk;
36+
uint32_t threadstop = (i == n_tasks - 1) ? end : threadstart + chunk;
37+
for (uint32_t it = threadstart; it < threadstop; ++it) {
38+
SerialFunction(it, params...);
39+
}
40+
m_tasksRemaining--;
41+
});
42+
}
43+
Finish();
44+
}
45+
46+
protected:
47+
// threads and task queue
48+
vector<thread> m_workers;
49+
atomic<int> m_tasksRemaining;
50+
bool m_stopWorkers;
51+
};
52+
53+
#endif /* end of include guard: THREADPOOL_H */

src/LockfreeThreadPool.cc

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/ThreadPool.cc

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/ThreadPoolBase.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include "ThreadPoolBase.hh"
2+
3+
void ThreadPoolBase::JoinAll() {
4+
for (auto& worker : m_workers) { worker.join(); }
5+
}
6+
void ThreadPoolBase::Finish() {
7+
while (m_tasksRemaining > 0) {}
8+
}

0 commit comments

Comments
 (0)