forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy paththread_pool.cpp
165 lines (145 loc) · 4.26 KB
/
thread_pool.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <c10/core/thread_pool.h>
#include <c10/util/Logging.h>
#if !defined(__powerpc__) && !defined(__s390x__)
#include <cpuinfo.h>
#endif
namespace c10 {
// Chooses the default number of worker threads.
//
// Where cpuinfo is compiled in (everything except powerpc and s390x) and it
// initializes and reports a positive processor count, that count is returned.
// Otherwise the result is std::thread::hardware_concurrency(), bumped to 1
// when that call yields 0.
size_t TaskThreadPoolBase::defaultNumThreads() {
#if !defined(__powerpc__) && !defined(__s390x__)
  if (cpuinfo_initialize()) {
    const size_t cpuinfo_count = cpuinfo_get_processors_count();
    if (cpuinfo_count > 0) {
      return cpuinfo_count;
    }
  }
#endif
  // Fallback path: never hand back zero threads.
  const size_t hw_count = std::thread::hardware_concurrency();
  return hw_count == 0 ? 1 : hw_count;
}
// Creates the pool and starts its worker threads.
//
// pool_size:    number of workers; a negative value selects
//               defaultNumThreads().
// numa_node_id: stored in numa_node_id_.
// init_thread:  when it holds a callable, each worker invokes it once before
//               entering main_loop().
ThreadPool::ThreadPool(
    int pool_size,
    int numa_node_id,
    const std::function<void()>& init_thread)
    : threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
      running_(true),
      complete_(true),
      available_(threads_.size()),
      total_(threads_.size()),
      numa_node_id_(numa_node_id) {
  std::size_t worker_index = 0;
  for (auto& worker : threads_) {
    // init_thread is captured by value, so every worker holds its own copy
    // of the callable.
    worker = std::thread([this, worker_index, init_thread]() {
      if (init_thread) {
        init_thread();
      }
      main_loop(worker_index);
    });
    ++worker_index;
  }
}
// Shuts the pool down: clears running_ and wakes every worker while holding
// the mutex, then joins each thread. A std::exception thrown by join() is
// swallowed.
ThreadPool::~ThreadPool() {
  {
    std::lock_guard<std::mutex> guard(mutex_);
    running_ = false;
    condition_.notify_all();
  }
  for (std::size_t i = 0; i < threads_.size(); ++i) {
    try {
      threads_[i].join();
    } catch (const std::exception&) {
      // Join failures are ignored here.
    }
  }
}
// Returns the number of worker threads held by the pool.
size_t ThreadPool::size() const {
  const size_t worker_count = threads_.size();
  return worker_count;
}
// Returns the current value of available_, read while holding the mutex.
size_t ThreadPool::numAvailable() const {
  std::lock_guard<std::mutex> guard(mutex_);
  const size_t snapshot = available_;
  return snapshot;
}
bool ThreadPool::inThreadPool() const {
for (auto& thread : threads_) {
if (thread.get_id() == std::this_thread::get_id()) {
return true;
}
}
return false;
}
// Queues func for execution and wakes one worker.
//
// Throws std::runtime_error when the pool holds no threads.
void ThreadPool::run(std::function<void()> func) {
  const bool has_workers = !threads_.empty();
  if (!has_workers) {
    throw std::runtime_error("No threads to run a task");
  }
  std::lock_guard<std::mutex> guard(mutex_);
  // Enqueue first, then mark the pool as having outstanding work, and
  // finally signal a single waiting worker.
  tasks_.emplace(std::move(func));
  complete_ = false;
  condition_.notify_one();
}
// Blocks the calling thread until complete_ is true.
void ThreadPool::waitWorkComplete() {
  std::unique_lock<std::mutex> lock(mutex_);
  // Re-check the flag after every wakeup.
  while (!complete_) {
    completed_.wait(lock);
  }
}
// Body executed by every worker thread.
//
// Repeatedly waits for a queued task, runs it with the mutex released, and
// then updates the availability / completion bookkeeping under the mutex.
// Returns once running_ has been cleared. Exceptions escaping a task are
// logged and do not terminate the worker.
//
// index: the worker index supplied by the constructor; passed to tasks that
//        have run_with_id set.
void ThreadPool::main_loop(std::size_t index) {
  std::unique_lock<std::mutex> lock(mutex_);
  while (running_) {
    // Wait on condition variable while the task is empty and
    // the pool is still running.
    condition_.wait(lock, [&]() { return !tasks_.empty() || !running_; });
    // If pool is no longer running, break out of loop.
    if (!running_) {
      break;
    }

    // Copy task locally and remove from the queue. This is
    // done within its own scope so that the task object is
    // destructed immediately after running the task. This is
    // useful in the event that the function contains
    // shared_ptr arguments bound via bind.
    {
      task_element_t tasks = std::move(tasks_.front());
      tasks_.pop();
      // Decrement count, indicating thread is no longer available.
      --available_;

      lock.unlock();

      // Run the task.
      try {
        if (tasks.run_with_id) {
          tasks.with_id(index);
        } else {
          tasks.no_id();
        }
      } catch (const std::exception& e) {
        LOG(ERROR) << "Exception in thread pool task: " << e.what();
      } catch (...) {
        LOG(ERROR) << "Exception in thread pool task: unknown";
      }

      // Destruct tasks before taking the lock. As tasks
      // are user provided std::function, they can run
      // arbitrary code during destruction, including code
      // that can reentrantly call into ThreadPool (which would
      // cause a deadlock if we were holding the lock).
    }

    // Update status of empty, maybe
    // Need to recover the lock first
    lock.lock();

    // Increment count, indicating thread is available.
    ++available_;
    if (tasks_.empty() && available_ == total_) {
      complete_ = true;
      // Wake every thread blocked in waitWorkComplete(). notify_one() would
      // release only a single waiter, leaving any other concurrent waiters
      // blocked even though complete_ is already true.
      completed_.notify_all();
    }

    // Deliberately hold the lock on the backedge, so this thread has an
    // opportunity to acquire a new task before another thread acquires
    // the lock.
  } // while running_
}
// Defines ThreadPoolRegistry for TaskThreadPoolBase through c10's
// shared-registry macro, passing int, int, bool as the remaining macro
// arguments.
// NOTE(review): presumably int, int, bool are the argument types forwarded to
// registered pool creators, and "shared" refers to shared_ptr ownership of
// the created pools — confirm against the C10_DEFINE_SHARED_REGISTRY
// definition.
C10_DEFINE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
} // namespace c10