Skip to content

Commit 08f21d8

Browse files
author
Chris Sullivan
committed
Merge pull request #1 from ExtremeScale/master
Adding ParallelFor templated over any callable type with variadic templated arguments. Credits: C. Sullivan and J. Bradt
2 parents b77bbe0 + cdcced2 commit 08f21d8

File tree

11 files changed

+10712
-67
lines changed

11 files changed

+10712
-67
lines changed

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Build script for the ThreadPool shared library and its test executable.

# cmake_minimum_required() has to run before project() so that the requested
# policies are in effect while project() sets up the toolchain.
cmake_minimum_required(VERSION 3.1)
project(ThreadPool)

set(CMAKE_CXX_STANDARD 11)

# The library is built on std::thread, which needs the platform thread
# library at link time (for example -pthread on Linux).
find_package(Threads REQUIRED)

include_directories(include)
# Warning switches are compiler options, not preprocessor definitions.
add_compile_options(-Wall -Wextra -pedantic)

add_library(ThreadPool SHARED src/ThreadPool.cc)
target_link_libraries(ThreadPool Threads::Threads)
install(TARGETS ThreadPool LIBRARY DESTINATION lib)
# ThreadPool.hh includes ThreadsafeQueue.hh, so both headers must be installed
# for the installed ThreadPool.hh to be usable.
install(FILES include/ThreadPool.hh include/ThreadsafeQueue.hh DESTINATION include)

add_executable(test_ThreadPool test/test_ParallelMap.cc test/catch_main.cc)
target_link_libraries(test_ThreadPool ThreadPool)

README.md

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,25 @@ A simple std::thread based thread pool with a parallel-for loop implementation.
44

55
Example of how to perform STREAM benchmark scale operation in parallel:
66
```C++
7-
#include "ThreadPool.hh"
8-
9-
static double* a;
10-
static double* b;
11-
12-
using namespace std;
7+
void scale(int i, double* a, double* b) {
8+
a[i]=4*b[i];
9+
}
1310

1411
int main () {
1512

16-
ThreadPool pool(4);
13+
ThreadPool pool(8);
1714

18-
int N = 1e7;
19-
a = (double*)calloc(N,sizeof(double));
20-
b = (double*)calloc(N,sizeof(double));
15+
int N = 1e9;
16+
auto a = (double*)calloc(N,sizeof(double));
17+
auto b = (double*)calloc(N,sizeof(double));
2118
for (int i=0; i<N; i++) { b[i] = i; }
2219

23-
SERIAL_OPERATION(scale, a[i]=4*b[i]);
24-
pool.ParallelFor<scale>(0,N);
20+
{
21+
Timer timer([&](int elapsed) { cout << elapsed*1e-6 << " ms\n"; });
22+
pool.ParallelFor(0,N,scale,a,b);
23+
}
24+
2525

26-
cin.get();
2726
return 0;
2827
}
2928

default.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ CFLAGS_EXTRA =
6464

6565
# Mandatory arguments to the C++ compiler. These arguments will be
6666
# passed even if CXXFLAGS has been overridden by command-line arguments.
67-
CXXFLAGS_EXTRA = -std=c++11
67+
CXXFLAGS_EXTRA = -std=c++1y
6868

6969
# Mandatory arguments to the linker, before the listing of object
7070
# files. These arguments will be passed even if LDFLAGS has been

example.cc

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,54 @@
11

22
#include "ThreadPool.hh"
3+
#include "Timer.hh"
4+
5+
#include <iostream>
36

4-
static double* a;
5-
static double* b;
67

78
using namespace std;
89

10+
11+
void scale(int i, double* a, double* b) {
12+
a[i]=4*b[i];
13+
}
14+
915
int main () {
1016

11-
ThreadPool pool(4);
17+
ThreadPool pool(8);
1218

13-
int N = 1e7;
14-
a = (double*)calloc(N,sizeof(double));
15-
b = (double*)calloc(N,sizeof(double));
19+
int N = 1e9;
20+
auto a = (double*)calloc(N,sizeof(double));
21+
auto b = (double*)calloc(N,sizeof(double));
1622
for (int i=0; i<N; i++) { b[i] = i; }
1723

18-
SERIAL_OPERATION(scale, a[i]=4*b[i]);
19-
pool.ParallelFor<scale>(0,N);
2024

21-
cin.get();
25+
// cold start for timing purposes
26+
pool.ParallelFor(0,N,scale,a,b);
27+
28+
int ntrials = 10;
29+
double tperformance = 0.0;
30+
cout << "Callable: c-function pointer\n";
31+
for (int i=0; i<ntrials; i++)
32+
{
33+
Timer timer([&](int elapsed){
34+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
35+
tperformance+=elapsed;
36+
});
37+
pool.ParallelFor(0,N,scale,a,b);
38+
}
39+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
40+
41+
tperformance = 0.0;
42+
cout << "Callable: lambda function (without capture) \n";
43+
for (int i=0; i<ntrials; i++)
44+
{
45+
Timer timer([&](int elapsed){
46+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
47+
tperformance+=elapsed;
48+
});
49+
pool.ParallelFor(0,N,[](int k, double* a, double* b) {return a[k] = 4*b[k];},a,b);
50+
}
51+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
52+
2253
return 0;
2354
}

include/ThreadPool.hh

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
#include <thread>
1+
#ifndef THREADPOOL_H
2+
#define THREADPOOL_H
3+
24
#include <iostream>
3-
#include <string>
45
#include <chrono>
56
#include <thread>
67
#include <functional>
78
#include <vector>
8-
#include <queue>
9-
#include <mutex>
10-
#include <condition_variable>
9+
#include <utility>
10+
#include <future>
11+
12+
#include "ThreadsafeQueue.hh"
1113

1214
using namespace std;
1315

16+
typedef std::function<void()> WorkType;
17+
1418
class ThreadPool {
1519
public:
1620
ThreadPool(uint32_t numthreads);
@@ -19,34 +23,54 @@ public:
1923
void Enqueue(function<void()> task);
2024
void Worker();
2125
void JoinAll();
26+
void Finish();
2227

23-
template <typename ClassFunction>
24-
void ParallelFor(uint32_t begin, uint32_t end) {
28+
template <typename T, typename... Params>
29+
void ParallelFor(uint32_t begin, uint32_t end, T SerialFunction, Params&&... params) {
2530

2631
int chunk = (end - begin) / m_nthreads;
2732
for (int i = 0; i < m_nthreads; ++i) {
28-
Enqueue([=]{
29-
int threadstart = begin + i*chunk;
30-
int threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
31-
for (int it = threadstart; it < threadstop; ++it) {
32-
ClassFunction::Serial(it);
33+
m_promises.emplace_back();
34+
int mypromise = m_promises.size() - 1;
35+
m_taskQueue.push([=]{
36+
uint32_t threadstart = begin + i*chunk;
37+
uint32_t threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
38+
for (uint32_t it = threadstart; it < threadstop; ++it) {
39+
SerialFunction(it, params...);
3340
}
41+
m_promises[mypromise].set_value();
3442
});
3543
}
44+
Finish();
45+
}
46+
template<typename InputIt, typename T>
47+
void ParallelMap(InputIt begin, InputIt end, InputIt outputBegin, const std::function<T(T)>& func)
48+
{
49+
int chunkSize = (end - begin) / m_nthreads;
50+
for (int i = 0; i < m_nthreads; i++) {
51+
m_promises.emplace_back();
52+
int mypromise = m_promises.size() - 1;
53+
m_taskQueue.push([=]{
54+
InputIt threadBegin = begin + i*chunkSize;
55+
InputIt threadOutput = outputBegin + i*chunkSize;
56+
InputIt threadEnd = (i == m_nthreads - 1) ? end : threadBegin + chunkSize;
57+
while (threadBegin != threadEnd) {
58+
*(threadOutput++) = func(*(threadBegin++));
59+
}
60+
m_promises[mypromise].set_value();
61+
});
62+
}
63+
Finish();
3664
}
3765

3866
private:
3967
// threads and task queue
4068
int m_nthreads;
4169
vector<thread> m_workers;
42-
queue<function<void()>> m_tasks;
43-
44-
// thread synchronization members
45-
bool stop_workers;
46-
condition_variable m_condition;
47-
mutable mutex m_taskmutex;
70+
vector<promise<void>> m_promises;
71+
bool m_stopWorkers;
4872

73+
ThreadsafeQueue<WorkType> m_taskQueue;
4974
};
5075

51-
// Macro to define a static class function which can be called via ThreadPool::ParallelFor<T>
52-
#define SERIAL_OPERATION(name, function_kernal) class name { public: static void Serial(const int& i) { function_kernal; } };
76+
#endif /* end of include guard: THREADPOOL_H */

include/ThreadsafeQueue.hh

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#ifndef THREADSAFEQUEUE_H
2+
#define THREADSAFEQUEUE_H
3+
4+
#include <thread>
5+
#include <queue>
6+
#include <mutex>
7+
#include <condition_variable>
8+
9+
// Bounded FIFO queue guarded by a single mutex and condition variable.
//
// push() blocks while the queue already holds m_maxSize items, pop() blocks
// while it is empty. join() waits for the queue to drain and then marks it
// finished; from then on pop() and push() throw QueueFinished instead of
// blocking.
template<typename T>
class ThreadsafeQueue
{
public:
    // Thrown by pop() and push() once join() has marked the queue finished.
    class QueueFinished : public std::exception
    {
    public:
        const char* what() const noexcept override { return "Queue has been joined"; }
    };

    // Creates a queue with the default capacity of 20 items.
    ThreadsafeQueue() : m_finished(false), m_maxSize(20) {}

    // Creates a queue that holds at most maxSize items. Explicit so that an
    // integer is never silently converted into a queue.
    explicit ThreadsafeQueue(const size_t maxSize) : m_finished(false), m_maxSize(maxSize) {}

    // The queue owns a mutex and a condition variable, so it is neither
    // copyable nor movable.
    ThreadsafeQueue(const ThreadsafeQueue&) = delete;
    ThreadsafeQueue(ThreadsafeQueue&&) = delete;
    ThreadsafeQueue& operator=(const ThreadsafeQueue&) = delete;
    ThreadsafeQueue& operator=(ThreadsafeQueue&&) = delete;

    // Appends a copy of task, blocking while the queue is full.
    // Throws QueueFinished if the queue has been joined.
    void push(const T& task)
    {
        std::unique_lock<std::mutex> lock(m_qmtx);
        waitForSpace(lock);
        m_contents.push(task);
        m_cond.notify_all();
    }

    // Appends task by moving it in, blocking while the queue is full.
    // Throws QueueFinished if the queue has been joined.
    void push(T&& task)
    {
        std::unique_lock<std::mutex> lock(m_qmtx);
        waitForSpace(lock);
        m_contents.push(std::move(task));
        m_cond.notify_all();
    }

    // Removes and returns the oldest item, blocking while the queue is empty.
    // Throws QueueFinished once join() has marked the queue finished.
    T pop()
    {
        std::unique_lock<std::mutex> lock(m_qmtx);
        m_cond.wait(lock, [this]{ return !m_contents.empty() || m_finished; });
        if (m_finished) throw QueueFinished();
        T item = std::move(m_contents.front());
        m_contents.pop();
        // Notifying while the lock is still held is valid: woken threads just
        // wait for the mutex, which is released when this function returns.
        m_cond.notify_all();
        return item;
    }

    // Blocks until the queue is empty, then marks it finished and wakes every
    // waiting thread so that blocked pop()/push() calls throw QueueFinished.
    void join()
    {
        std::unique_lock<std::mutex> lock(m_qmtx);
        m_cond.wait(lock, [this]{ return m_contents.empty(); });
        m_finished = true;
        m_cond.notify_all();
    }

private:
    // Waits (with m_qmtx held through lock) until there is room for one more
    // item. Throws QueueFinished if the queue has been joined: nothing pops
    // from a finished queue, so an item pushed now would never be consumed and
    // a producer waiting for space would wait forever.
    void waitForSpace(std::unique_lock<std::mutex>& lock)
    {
        m_cond.wait(lock, [this]{ return m_finished || m_contents.size() < m_maxSize; });
        if (m_finished) throw QueueFinished();
    }

    std::queue<T> m_contents;        // pending items, oldest at the front
    std::mutex m_qmtx;               // guards m_contents and m_finished
    std::condition_variable m_cond;  // signalled on every push, pop and join
    bool m_finished;                 // set by join(); never cleared
    size_t m_maxSize;                // capacity above which push() blocks
};
66+
67+
#endif /* end of include guard: THREADSAFEQUEUE_H */

include/Timer.hh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TIMER_H
2+
#define TIMER_H
3+
#include <functional>
4+
#include <chrono>
5+
6+
using namespace std;
7+
8+
// Scope timer: records the time at construction and, on destruction, passes
// the elapsed time in nanoseconds to the callback.
//
// The elapsed count is delivered as long long. The previous int parameter
// overflowed for intervals longer than about 2.1 seconds. Callables that take
// an int are still accepted, but they truncate the value themselves.
struct Timer {
    // callback: invoked exactly once, from the destructor, with the elapsed
    // nanoseconds. It must not throw, because it runs inside a destructor.
    Timer(std::function<void(long long)> callback)
        : callback(std::move(callback))
        , t0(std::chrono::high_resolution_clock::now()) { ; }

    // A copy would report the same interval a second time, so copying is disabled.
    Timer(const Timer&) = delete;
    Timer& operator=(const Timer&) = delete;

    ~Timer(void) {
        const auto t1 = std::chrono::high_resolution_clock::now();
        const long long elapsed =
            std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0).count();
        // An empty std::function would throw std::bad_function_call here.
        if (callback) { callback(elapsed); }
    }

    std::function<void(long long)> callback;               // receives elapsed nanoseconds
    std::chrono::high_resolution_clock::time_point t0;     // construction time

};
21+
22+
#endif

src/ThreadPool.cc

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,33 @@
11
#include "ThreadPool.hh"
2-
ThreadPool::ThreadPool(uint32_t numthreads) : m_nthreads(numthreads), stop_workers(false) {
32

3+
// Starts numthreads worker threads, each running ThreadPool::Worker on this pool.
ThreadPool::ThreadPool(uint32_t numthreads) : m_nthreads(numthreads), m_stopWorkers(false) {
    uint32_t spawned = 0;
    while (spawned != numthreads) {
        m_workers.emplace_back(&ThreadPool::Worker, this);
        ++spawned;
    }
}
88
ThreadPool::~ThreadPool() {
9-
10-
stop_workers = true;
11-
m_condition.notify_all();
9+
m_stopWorkers = true;
10+
m_taskQueue.join();
1211
JoinAll();
1312
}
14-
void ThreadPool::Enqueue(function<void()> task) {
15-
{
16-
unique_lock<mutex> lock(m_taskmutex);
17-
m_tasks.push(task);
18-
}
19-
m_condition.notify_one();
20-
}
13+
2114
void ThreadPool::Worker() {
15+
while(true) {
2216
function<void()> work;
23-
while(true) {
24-
{
25-
unique_lock<mutex> lock(m_taskmutex);
26-
cout << "Worker waiting..." << endl;
27-
while(!stop_workers && m_tasks.empty()) {
28-
m_condition.wait(lock);
29-
}
30-
if (stop_workers) { return; }
31-
cout << "Booting up..." << endl;
32-
work = m_tasks.front();
33-
m_tasks.pop();
34-
}
35-
work();
17+
try {
18+
work = m_taskQueue.pop();
19+
work();
20+
}
21+
catch (const ThreadsafeQueue<WorkType>::QueueFinished&)
22+
{
23+
return;
3624
}
25+
}
3726
}
3827
void ThreadPool::JoinAll() {
3928
for (auto& worker : m_workers) { worker.join(); }
4029
}
30+
// Waits on the future of every promise in m_promises, in order, then empties
// the vector.
void ThreadPool::Finish() {
    for (auto it = m_promises.begin(); it != m_promises.end(); ++it) {
        it->get_future().get();
    }
    m_promises.clear();
}

0 commit comments

Comments
 (0)