Skip to content

Commit 4f69eec

Browse files
author
Chris Sullivan
committed
Merging Timer class as well as std::promise
based (blocking) ThreadPool::Finish, and also a better ParallelFor. See merged commit logs for details.
2 parents 7761703 + faa2f04 commit 4f69eec

File tree

6 files changed

+103
-36
lines changed

6 files changed

+103
-36
lines changed

README.md

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,37 @@ A simple std::thread based thread pool with a parallel-for loop implementation.
44

55
Example of how to perform STREAM benchmark scale operation in parallel:
66
```C++
7-
#include "ThreadPool.hh"
8-
9-
static double* a;
10-
static double* b;
11-
12-
using namespace std;
7+
void scale(int i, double* a, double* b) {
8+
a[i]=4*b[i];
9+
}
1310

1411
int main () {
1512

16-
ThreadPool pool(4);
13+
ThreadPool pool(8);
1714

18-
int N = 1e7;
19-
a = (double*)calloc(N,sizeof(double));
20-
b = (double*)calloc(N,sizeof(double));
15+
int N = 1e9;
16+
auto a = (double*)calloc(N,sizeof(double));
17+
auto b = (double*)calloc(N,sizeof(double));
2118
for (int i=0; i<N; i++) { b[i] = i; }
2219

23-
SERIAL_OPERATION(scale, a[i]=4*b[i]);
24-
pool.ParallelFor<scale>(0,N);
2520

26-
cin.get();
21+
// cold start for timing purposes
22+
pool.ParallelFor(0,N,scale,a,b);
23+
24+
int ntrials = 10;
25+
double tperformance = 0.0;
26+
for (int i=0; i<ntrials; i++)
27+
{
28+
Timer timer([&](int elapsed) {
29+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
30+
tperformance+=elapsed;
31+
});
32+
pool.ParallelFor(0,N,scale,a,b);
33+
}
34+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
35+
36+
37+
2738
return 0;
2839
}
2940

default.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ CFLAGS_EXTRA =
6464

6565
# Mandatory arguments to the C++ compiler. These arguments will be
6666
# passed even if CXXFLAGS has been overridden by command-line arguments.
67-
CXXFLAGS_EXTRA = -std=c++11
67+
CXXFLAGS_EXTRA = -std=c++1y
6868

6969
# Mandatory arguments to the linker, before the listing of object
7070
# files. These arguments will be passed even if LDFLAGS has been

example.cc

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,43 @@
11

22
#include "ThreadPool.hh"
3+
#include "Timer.hh"
4+
5+
#include <iostream>
36

4-
static double* a;
5-
static double* b;
67

78
using namespace std;
89

10+
11+
void scale(int i, double* a, double* b) {
12+
a[i]=4*b[i];
13+
}
14+
915
int main () {
1016

11-
ThreadPool pool(4);
17+
ThreadPool pool(8);
1218

13-
int N = 1e7;
14-
a = (double*)calloc(N,sizeof(double));
15-
b = (double*)calloc(N,sizeof(double));
19+
int N = 1e9;
20+
auto a = (double*)calloc(N,sizeof(double));
21+
auto b = (double*)calloc(N,sizeof(double));
1622
for (int i=0; i<N; i++) { b[i] = i; }
1723

18-
SERIAL_OPERATION(scale, a[i]=4*b[i]);
19-
pool.ParallelFor<scale>(0,N);
2024

21-
cin.get();
25+
// cold start for timing purposes
26+
pool.ParallelFor(0,N,scale,a,b);
27+
28+
int ntrials = 10;
29+
double tperformance = 0.0;
30+
for (int i=0; i<ntrials; i++)
31+
{
32+
Timer timer([&](int elapsed){
33+
cout << "Trial " << i << ": "<< elapsed*1e-6 << " ms\n";
34+
tperformance+=elapsed;
35+
});
36+
pool.ParallelFor(0,N,scale,a,b);
37+
}
38+
cout << "Average: " << tperformance*1e-6 / ntrials << " ms\n\n";
39+
40+
41+
2242
return 0;
2343
}

include/ThreadPool.hh

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#ifndef THREADPOOL_H
22
#define THREADPOOL_H
33

4-
#include <thread>
4+
#include <iostream>
55
#include <chrono>
6+
#include <thread>
67
#include <functional>
78
#include <vector>
9+
#include <utility>
10+
#include <future>
811

912
#include "ThreadsafeQueue.hh"
1013

@@ -17,50 +20,57 @@ public:
1720
ThreadPool(uint32_t numthreads);
1821
~ThreadPool();
1922

23+
void Enqueue(function<void()> task);
2024
void Worker();
2125
void JoinAll();
26+
void Finish();
2227

23-
template <typename ClassFunction>
24-
void ParallelFor(uint32_t begin, uint32_t end) {
28+
template <typename T, typename... Params>
29+
void ParallelFor(uint32_t begin, uint32_t end, T SerialFunction, Params&&... params) {
2530

2631
int chunk = (end - begin) / m_nthreads;
2732
for (int i = 0; i < m_nthreads; ++i) {
33+
m_promises.emplace_back();
34+
int mypromise = m_promises.size() - 1;
2835
m_taskQueue.push([=]{
29-
int threadstart = begin + i*chunk;
30-
int threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
31-
for (int it = threadstart; it < threadstop; ++it) {
32-
ClassFunction::Serial(it);
33-
}
34-
});
36+
uint32_t threadstart = begin + i*chunk;
37+
uint32_t threadstop = (i == m_nthreads - 1) ? end : threadstart + chunk;
38+
for (uint32_t it = threadstart; it < threadstop; ++it) {
39+
SerialFunction(it, params...);
40+
}
41+
m_promises[mypromise].set_value();
42+
});
3543
}
44+
Finish();
3645
}
37-
3846
template<typename InputIt, typename T>
3947
void ParallelMap(InputIt begin, InputIt end, InputIt outputBegin, const std::function<T(T)>& func)
4048
{
4149
int chunkSize = (end - begin) / m_nthreads;
4250
for (int i = 0; i < m_nthreads; i++) {
51+
m_promises.emplace_back();
52+
int mypromise = m_promises.size() - 1;
4353
m_taskQueue.push([=]{
4454
InputIt threadBegin = begin + i*chunkSize;
4555
InputIt threadOutput = outputBegin + i*chunkSize;
4656
InputIt threadEnd = (i == m_nthreads - 1) ? end : threadBegin + chunkSize;
4757
while (threadBegin != threadEnd) {
4858
*(threadOutput++) = func(*(threadBegin++));
4959
}
60+
m_promises[mypromise].set_value();
5061
});
5162
}
63+
Finish();
5264
}
5365

5466
private:
5567
// threads and task queue
5668
int m_nthreads;
5769
vector<thread> m_workers;
70+
vector<promise<void>> m_promises;
5871
bool m_stopWorkers;
5972

6073
ThreadsafeQueue<WorkType> m_taskQueue;
6174
};
6275

63-
// Macro to define a static class function which can be called via ThreadPool::ParallelFor<T>
64-
#define SERIAL_OPERATION(name, function_kernel) class name { public: static void Serial(const int& i) { function_kernel; } };
65-
6676
#endif /* end of include guard: THREADPOOL_H */

include/Timer.hh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TIMER_H
2+
#define TIMER_H
3+
#include <functional>
4+
#include <chrono>
5+
6+
using namespace std;
7+
8+
struct Timer {
9+
Timer(function<void(int)> callback)
10+
: callback(callback)
11+
, t0(chrono::high_resolution_clock::now()) { ; }
12+
~Timer(void) {
13+
auto t1 = chrono::high_resolution_clock::now();
14+
auto elapsed = chrono::duration_cast<chrono::nanoseconds>(t1-t0).count();
15+
callback(elapsed);
16+
}
17+
function<void(int)> callback;
18+
chrono::high_resolution_clock::time_point t0;
19+
20+
};
21+
22+
#endif

src/ThreadPool.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ void ThreadPool::Worker() {
2727
void ThreadPool::JoinAll() {
2828
for (auto& worker : m_workers) { worker.join(); }
2929
}
30+
void ThreadPool::Finish() {
31+
for (auto& promise : m_promises) promise.get_future().get();
32+
m_promises.clear();
33+
}

0 commit comments

Comments
 (0)