Skip to content

Commit a803548

Browse files
author
Chris Sullivan
committed
Added blocking ThreadPool::Finish() which waits until workers finish the task queue.
Added a scoped Timer class for performance measurements. Also added variadic template ParallelFor function which accepts any callable type and forwards the parameter pack to the callable type. At the moment the variadic parameter pack is captured by value, capture by reference would be much nicer, but unfortunately C++14 lambdas can't yet capture by move or forward semantics. I understand that this can possibly be handled by creating a wraper class that holds something like a std::ref, but I haven't dug into it much yet, since the parameters only are copied once for each task, and so the overhead is not too large. But something to add to the TODO list.
1 parent 00ea301 commit a803548

File tree

4 files changed

+80
-12
lines changed

4 files changed

+80
-12
lines changed

example.cc

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11

22
#include "ThreadPool.hh"
3+
#include "Timer.hh"
4+
5+
#include <iostream>
36

47
static double* a;
58
static double* b;
@@ -13,17 +16,41 @@ void raw_scale(int i, double* a, double* b) {
1316

1417
int main () {
1518

16-
ThreadPool pool(4);
19+
ThreadPool pool(8);
1720

18-
int N = 1e7;
21+
int N = 1e9;
1922
a = (double*)calloc(N,sizeof(double));
2023
b = (double*)calloc(N,sizeof(double));
2124
for (int i=0; i<N; i++) { b[i] = i; }
2225

2326
SERIAL_OPERATION(static_scale, a[i]=4*b[i]);
24-
pool.ParallelFor<static_scale>(0,N);
25-
26-
pool.ParallelFor(0,N,raw_scale,a,b);
27+
{
28+
// This is a cold start loop, it is not timed
29+
pool.ParallelFor<static_scale>(0,N);
30+
}
31+
32+
int ntrials = 10;
33+
double tperformance = 0.0;
34+
for (int i=0; i<ntrials; i++)
35+
{
36+
Timer timer([&](int elapsed){
37+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
38+
tperformance+=elapsed;
39+
});
40+
pool.ParallelFor<static_scale>(0,N);
41+
}
42+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
43+
44+
tperformance = 0.0;
45+
for (int i=0; i<ntrials; i++)
46+
{
47+
Timer timer([&](int elapsed){
48+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
49+
tperformance+=elapsed;
50+
});
51+
pool.ParallelFor(0,N,raw_scale,a,b);
52+
}
53+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
2754

2855

2956
cin.get();

include/ThreadPool.hh

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <mutex>
1010
#include <condition_variable>
1111
#include <utility>
12+
#include <atomic>
13+
#include <future>
1214

1315
using namespace std;
1416

@@ -26,42 +28,57 @@ public:
2628

2729
int chunk = (end - begin) / m_nthreads;
2830
for (int i = 0; i < m_nthreads; ++i) {
31+
m_promises.emplace_back();
32+
int mypromise = m_promises.size() - 1;
2933
Enqueue([=]{
3034
int threadstart = begin + i*chunk;
3135
int threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
3236
for (int it = threadstart; it < threadstop; ++it) {
33-
ClassFunction::Serial(it);
37+
ClassFunction::SerialFunction(it);
3438
}
39+
m_promises[mypromise].set_value();
3540
});
3641
}
42+
Finish();
3743
}
3844
template <typename T, typename... Params>
39-
void ParallelFor(uint32_t begin, uint32_t end, T func, Params&&... params) {
45+
void ParallelFor(uint32_t begin, uint32_t end, T SerialFunction, Params&&... params) {
4046

4147
int chunk = (end - begin) / m_nthreads;
4248
for (int i = 0; i < m_nthreads; ++i) {
49+
m_promises.emplace_back();
50+
int mypromise = m_promises.size() - 1;
4351
Enqueue([=]{
4452
uint32_t threadstart = begin + i*chunk;
4553
uint32_t threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
4654
for (uint32_t it = threadstart; it < threadstop; ++it) {
47-
func(it, params...);
55+
SerialFunction(it, params...);
4856
}
57+
m_promises[mypromise].set_value();
4958
});
5059
}
60+
Finish();
61+
}
62+
63+
void Finish() {
64+
for (auto& promise : m_promises) promise.get_future().get();
65+
m_promises.clear();
5166
}
5267

5368
private:
5469
// threads and task queue
5570
int m_nthreads;
5671
vector<thread> m_workers;
5772
queue<function<void()>> m_tasks;
73+
vector<promise<void>> m_promises;
5874

5975
// thread synchronization members
6076
bool stop_workers;
6177
condition_variable m_condition;
6278
mutable mutex m_taskmutex;
79+
atomic<int> m_nrunning;
6380

6481
};
6582

6683
// Macro to define a static class function which can be called via ThreadPool::ParallelFor<T>
67-
#define SERIAL_OPERATION(name, function_kernal) class name { public: static void Serial(const int& i) { function_kernal; } };
84+
#define SERIAL_OPERATION(name, function_kernal) class name { public: static void SerialFunction(const int& i) { function_kernal; } };

include/Timer.hh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TIMER_H
2+
#define TIMER_H
3+
#include <functional>
4+
#include <chrono>
5+
6+
using namespace std;
7+
8+
struct Timer {
9+
Timer(function<void(int)> callback)
10+
: callback(callback)
11+
, t0(chrono::high_resolution_clock::now()) { ; }
12+
~Timer(void) {
13+
auto t1 = chrono::high_resolution_clock::now();
14+
auto elapsed = chrono::duration_cast<chrono::nanoseconds>(t1-t0).count();
15+
callback(elapsed);
16+
}
17+
function<void(int)> callback;
18+
chrono::high_resolution_clock::time_point t0;
19+
20+
};
21+
22+
#endif

src/ThreadPool.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "ThreadPool.hh"
2-
ThreadPool::ThreadPool(uint32_t numthreads) : m_nthreads(numthreads), stop_workers(false) {
2+
ThreadPool::ThreadPool(uint32_t numthreads) : m_nthreads(numthreads), stop_workers(false),
3+
m_nrunning(0) {
34

45
for (uint32_t i=0; i<numthreads;i++) {
56
m_workers.emplace_back(&ThreadPool::Worker, this);
@@ -11,6 +12,7 @@ ThreadPool::~ThreadPool() {
1112
m_condition.notify_all();
1213
JoinAll();
1314
}
15+
// TODO: add futures and promises for timing purposes
1416
void ThreadPool::Enqueue(function<void()> task) {
1517
{
1618
unique_lock<mutex> lock(m_taskmutex);
@@ -23,12 +25,12 @@ void ThreadPool::Worker() {
2325
while(true) {
2426
{
2527
unique_lock<mutex> lock(m_taskmutex);
26-
cout << "Worker waiting..." << endl;
28+
//cout << "Worker waiting..." << endl;
2729
while(!stop_workers && m_tasks.empty()) {
2830
m_condition.wait(lock);
2931
}
3032
if (stop_workers) { return; }
31-
cout << "Booting up..." << endl;
33+
//cout << "Booting up..." << endl;
3234
work = m_tasks.front();
3335
m_tasks.pop();
3436
}

0 commit comments

Comments
 (0)