Skip to content

Commit 7c89fbf

Browse files
author
tnagler
committed
adapt functions
1 parent 3e4dd9a commit 7c89fbf

File tree

5 files changed

+134
-90
lines changed

5 files changed

+134
-90
lines changed

inst/include/RcppThread/ThreadPool.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ class ThreadPool
3636
ThreadPool& operator=(const ThreadPool&) = delete;
3737
ThreadPool& operator=(ThreadPool&& other) = delete;
3838

39+
//! @brief returns a reference to the global thread pool instance.
40+
static ThreadPool& globalInstance()
41+
{
42+
static ThreadPool instance_;
43+
return instance_;
44+
}
45+
3946
template<class F, class... Args>
4047
void push(F&& f, Args&&... args);
4148

inst/include/RcppThread/parallelFor.hpp

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@ namespace RcppThread {
1313

1414
//! computes an index-based for loop in parallel batches.
1515
//! @param begin first index of the loop.
16-
//! @param size the loop runs in the range `[begin, begin + size)`.
16+
//! @param end the loop runs in the range `[begin, end)`.
1717
//! @param f a function (the 'loop body').
18-
//! @param nThreads the number of threads to use; the default uses the number
19-
//! of cores in the machine; if `nThreads = 0`, all work will be done in the
20-
//! main thread.
18+
//! @param nThreads deprecated; loop is run on global thread pool.
2119
//! @param nBatches the number of batches to create; the default (0)
2220
//! triggers a heuristic to automatically determine the number of batches.
2321
//! @details Consider the following code:
@@ -33,29 +31,45 @@ namespace RcppThread {
3331
//! x[i] = i;
3432
//! });
3533
//! ```
36-
//! The function sets up a `ThreadPool` object to do the scheduling. If you
37-
//! want to run multiple parallel for loops, consider creating a `ThreadPool`
38-
//! yourself and using `ThreadPool::forEach()`.
34+
//! The function dispatches to a global thread pool, so it can safely be nested
35+
//! or called multiple times with almost no overhead.
3936
//!
4037
//! **Caution**: if the iterations are not independent of one another,
4138
//! the tasks need to be synchronized manually (e.g., using mutexes).
4239
template<class F>
43-
inline void parallelFor(ptrdiff_t begin, ptrdiff_t size, F&& f,
44-
size_t nThreads = std::thread::hardware_concurrency(),
45-
size_t nBatches = 0)
40+
inline void
41+
parallelFor(int begin,
42+
int end,
43+
F&& f,
44+
size_t nThreads = std::thread::hardware_concurrency(),
45+
size_t nBatches = 0)
4646
{
47-
ThreadPool pool(nThreads);
48-
pool.parallelFor(begin, size, f, nBatches);
49-
pool.join();
47+
if (end < begin)
48+
throw std::runtime_error("can only run forward loops");
49+
if (end == begin)
50+
return;
51+
52+
nThreads = std::thread::hardware_concurrency();
53+
auto batches = createBatches(begin, end - begin, nThreads, nBatches);
54+
tpool::FinishLine finishLine{ batches.size() };
55+
auto doBatch = [f, &finishLine](const Batch& b) {
56+
for (ptrdiff_t i = b.begin; i < b.end; i++)
57+
f(i);
58+
finishLine.cross();
59+
};
60+
for (const auto& batch : batches)
61+
ThreadPool::globalInstance().push(doBatch, batch);
62+
finishLine.wait();
5063
}
5164

5265
//! computes a range-based for loop in parallel batches.
5366
//! @param items an object allowing for `items.size()` and whose elements
5467
//! are accessed by the `[]` operator.
5568
//! @param f a function (the 'loop body').
56-
//! @param nThreads the number of threads to use; the default uses the number
57-
//! of cores in the machine; if `nThreads = 0`, all work will be done in the
58-
//! main thread.
69+
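//! @param nThreads the number of threads to use; the default uses the
70+
//! number of cores in the machine; if `nThreads = 0`, all work will be
71+
//! done in the main thread. NOTE: this argument is currently not forwarded
72+
//! to `parallelFor()`, so the loop always runs on the global thread pool.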
5973
//! @param nBatches the number of batches to create; the default (0)
6074
//! triggers a heuristic to automatically determine the number of batches.
6175
//! @details Consider the following code:
@@ -71,21 +85,23 @@ inline void parallelFor(ptrdiff_t begin, ptrdiff_t size, F&& f,
7185
//! xx *= 2;
7286
//! });
7387
//! ```
74-
//! The function sets up a `ThreadPool` object to do the scheduling. If you
75-
//! want to run multiple parallel for loops, consider creating a `ThreadPool`
76-
//! yourself and using `ThreadPool::forEach()`.
88+
//! The function dispatches to a global thread pool, so it can safely be nested
89+
//! or called multiple times with almost no overhead.
7790
//!
7891
//! **Caution**: if the iterations are not independent of one another,
7992
//! the tasks need to be synchronized manually (e.g., using mutexes).
8093
template<class I, class F>
81-
inline void parallelForEach(I& items, F&& f,
82-
size_t nThreads = std::thread::hardware_concurrency(),
83-
size_t nBatches = 0)
94+
inline void
95+
parallelForEach(I& items,
96+
F&& f,
97+
size_t nThreads = std::thread::hardware_concurrency(),
98+
size_t nBatches = 0)
8499
{
85-
ThreadPool pool(nThreads);
86-
pool.parallelForEach(items, f, nBatches);
87-
pool.join();
100+
// loop ranges indicate iterator offsets
101+
const auto begin_it = std::begin(items);
102+
const auto end_it = std::end(items);
103+
auto size = std::distance(begin_it, end_it);
104+
parallelFor(
105+
0, size, [f, &items, &begin_it](int i) { f(*(begin_it + i)); });
88106
}
89-
90-
91107
}

inst/include/RcppThread/tpool.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ class FinishLine
5454
//! indicates that a runner has crossed the finish line.
5555
void cross() noexcept
5656
{
57-
if (--runners_ <= 0)
57+
if (--runners_ <= 0) {
5858
cv_.notify_all();
59+
}
5960
}
6061

6162
//! waits for all active runners to cross the finish line.
@@ -188,18 +189,24 @@ class TaskQueue
188189
//! currently locked; enlarges the queue if full.
189190
bool try_push(Task&& task)
190191
{
192+
// must hold lock in case there are multiple producers, abort if already
193+
// taken, so we can check out next queue
194+
std::unique_lock<std::mutex> lk(mutex_, std::try_to_lock);
195+
if (!lk)
196+
return false;
197+
191198
auto b = bottom_.load(m_relaxed);
192199
auto t = top_.load(m_acquire);
193200
RingBuffer<Task>* buf_ptr = buffer_.load(m_relaxed);
194201

195202
if (buf_ptr->capacity() < (b - t) + 1) {
196-
// capacity reached, create buffer with double size
197203
old_buffers_.emplace_back(
198204
exchange(buf_ptr, buf_ptr->enlarge(b, t)));
199205
buffer_.store(buf_ptr, m_relaxed);
200206
}
201207

202208
buf_ptr->store(b, std::move(task));
209+
203210
std::atomic_thread_fence(m_release);
204211
bottom_.store(b + 1, m_relaxed);
205212

@@ -209,10 +216,6 @@ class TaskQueue
209216
//! pops a task from the top of the queue; returns false if lost race.
210217
bool try_pop(Task& task)
211218
{
212-
std::unique_lock<std::mutex> lk(mutex_, std::try_to_lock);
213-
if (!lk)
214-
return false;
215-
216219
auto t = top_.load(m_acquire);
217220
std::atomic_thread_fence(m_seq_cst);
218221
auto b = bottom_.load(m_acquire);

tests/tests.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,21 @@ void testThreadPoolNestedParallelFor()
162162
x[i][j] *= 2;
163163
});
164164
});
165-
pool.join();
165+
pool.wait();
166166

167167
size_t count_wrong = 0;
168168
for (auto xx : x) {
169169
for (auto xxx : xx)
170170
count_wrong += xxx != 2;
171171
}
172-
if (count_wrong > 0)
172+
if (count_wrong > 0) {
173+
for (auto xx : x) {
174+
for (auto xxx : xx)
175+
std::cout << xxx;
176+
std::cout << std::endl;
177+
}
173178
throw std::runtime_error("nested parallelFor gives wrong result");
179+
}
174180
}
175181

176182
// [[Rcpp::export]]
@@ -212,15 +218,21 @@ void testThreadPoolNestedParallelForEach()
212218
xxx *= 2;
213219
});
214220
});
215-
pool.join();
221+
pool.wait();
216222

217223
size_t count_wrong = 0;
218224
for (auto xx : x) {
219225
for (auto xxx : xx)
220226
count_wrong += xxx != 2;
221227
}
222-
if (count_wrong > 0)
223-
throw std::runtime_error("nested parallelForEach gives wrong result");
228+
if (count_wrong > 0) {
229+
for (auto xx : x) {
230+
for (auto xxx : xx)
231+
std::cout << xxx;
232+
std::cout << std::endl;
233+
}
234+
throw std::runtime_error("nested parallelFor gives wrong result");
235+
}
224236
}
225237

226238
// [[Rcpp::export]]
@@ -293,8 +305,14 @@ void testNestedParallelFor()
293305
for (auto xxx : xx)
294306
count_wrong += xxx != 2;
295307
}
296-
if (count_wrong > 0)
308+
if (count_wrong > 0) {
309+
for (auto xx : x) {
310+
for (auto xxx : xx)
311+
std::cout << xxx;
312+
std::cout << std::endl;
313+
}
297314
throw std::runtime_error("nested parallelFor gives wrong result");
315+
}
298316
}
299317

300318
// [[Rcpp::export]]
@@ -391,7 +409,7 @@ void testProgressCounter()
391409
{
392410
RcppThread::ProgressCounter cntr(20, 1);
393411
RcppThread::parallelFor(0, 20, [&] (int i) {
394-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
412+
std::this_thread::sleep_for(std::chrono::milliseconds(200));
395413
cntr++;
396414
});
397415
}
@@ -402,7 +420,7 @@ void testProgressBar()
402420
// 20 iterations in loop, update progress every 1 sec
403421
RcppThread::ProgressBar bar(20, 1);
404422
RcppThread::parallelFor(0, 20, [&] (int i) {
405-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
423+
std::this_thread::sleep_for(std::chrono::milliseconds(200));
406424
++bar;
407425
});
408426
}

tests/testthat/tests.R

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -34,52 +34,52 @@ test_that("parallelFor works", {
3434
expect_silent(testThreadPoolParallelFor())
3535
})
3636

37-
# test_that("nested parallelFor works", {
38-
# expect_silent(testThreadPoolNestedParallelFor())
39-
# })
40-
41-
# test_that("parallelForEach works", {
42-
# expect_silent(testThreadPoolParallelForEach())
43-
# })
44-
#
45-
# test_that("nested parallelForEach works", {
46-
# expect_silent(testThreadPoolNestedParallelForEach())
47-
# })
48-
#
49-
# test_that("works single threaded", {
50-
# expect_silent(testThreadPoolSingleThreaded())
51-
# })
52-
#
53-
# test_that("destructible without join", {
54-
# expect_silent(testThreadPoolDestructWOJoin())
55-
# })
56-
#
57-
#
58-
# ## -------------------------------------------------------
59-
# context("Parallel for functions")
60-
# test_that("parallelFor works", {
61-
# expect_silent(testParallelFor())
62-
# })
63-
#
64-
# test_that("nested parallelFor works", {
65-
# expect_silent(testNestedParallelFor())
66-
# })
67-
#
68-
# test_that("parallelForEach works", {
69-
# expect_silent(testParallelForEach())
70-
# })
71-
#
72-
# test_that("nested parallelForEach works", {
73-
# expect_silent(testNestedParallelForEach())
74-
# })
75-
#
76-
#
77-
# ## ------------------------------------------------------
78-
# context("Progress tracking")
79-
# test_that("ProgressCounter works", {
80-
# expect_output(testProgressCounter(), "100% \\(done\\)")
81-
# })
82-
#
83-
# test_that("ProgressBar works", {
84-
# expect_output(testProgressBar(),"100% \\(done\\)")
85-
# })
37+
test_that("nested parallelFor works", {
38+
expect_silent(testThreadPoolNestedParallelFor())
39+
})
40+
41+
test_that("parallelForEach works", {
42+
expect_silent(testThreadPoolParallelForEach())
43+
})
44+
45+
test_that("nested parallelForEach works", {
46+
expect_silent(testThreadPoolNestedParallelForEach())
47+
})
48+
49+
test_that("works single threaded", {
50+
expect_silent(testThreadPoolSingleThreaded())
51+
})
52+
53+
test_that("destructible without join", {
54+
expect_silent(testThreadPoolDestructWOJoin())
55+
})
56+
57+
58+
## -------------------------------------------------------
59+
context("Parallel for functions")
60+
test_that("parallelFor works", {
61+
expect_silent(testParallelFor())
62+
})
63+
64+
test_that("nested parallelFor works", {
65+
expect_silent(testNestedParallelFor())
66+
})
67+
68+
test_that("parallelForEach works", {
69+
expect_silent(testParallelForEach())
70+
})
71+
72+
test_that("nested parallelForEach works", {
73+
expect_silent(testNestedParallelForEach())
74+
})
75+
76+
77+
## ------------------------------------------------------
78+
context("Progress tracking")
79+
test_that("ProgressCounter works", {
80+
expect_output(testProgressCounter(), "100% \\(done\\)")
81+
})
82+
83+
test_that("ProgressBar works", {
84+
expect_output(testProgressBar(),"100% \\(done\\)")
85+
})

0 commit comments

Comments
 (0)