Skip to content

Commit

Permalink
Fix weird hang bug caused by cuInit sometimes calling fork (#8790)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored and szha committed Nov 23, 2017
1 parent 92d848f commit 7b40c03
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/engine/threaded_engine_perdevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
#ifndef _WIN32
pthread_atfork(
[]() {
Engine::Get()->WaitForAll();
Engine::Get()->Stop();
},
[]() {
Expand All @@ -71,18 +70,25 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
#endif
}
/*!
 * \brief Destructor: shuts the engine down through the calls below.
 *  Declared noexcept(false), so an exception may leave the destructor.
 *  NOTE(review): this listing is a rendered diff without +/- markers;
 *  `this->Stop();` is presumably the removed line and `this->StopNoWait();`
 *  its replacement - confirm against the repository before relying on both.
 */
~ThreadedEnginePerDevice() noexcept(false) {
this->Stop();
this->StopNoWait();
}

void Stop() override {
/*!
 * \brief Shut down the worker queues and release every worker block
 *  without calling WaitForAll() first (Stop() is the variant that waits).
 */
void StopNoWait() {
// presumably tells the task queues to terminate so workers can exit
// before their blocks are released below - TODO confirm in ThreadedEngine
SignalQueuesForKill();
// drop the per-device GPU worker blocks (normal and copy) ...
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
// ... then the CPU worker blocks and the single CPU priority worker
cpu_normal_workers_.Clear();
cpu_priority_worker_.reset(nullptr);
}

/*!
 * \brief Wait for all pending work, then shut the workers down.
 *  Does nothing when invoked on a worker thread (is_worker_ is set at the
 *  top of the worker routines). Presumably this keeps a fork triggered
 *  from inside a worker from waiting on work that includes the worker's
 *  own task - verify against the pthread_atfork handlers.
 */
void Stop() override {
  if (!is_worker_) {
    WaitForAll();
    StopNoWait();
  }
}

void Start() override {
if (is_worker_) return;
gpu_worker_nthreads_ = common::GetNumThreadPerGPU();
cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
// create CPU task
Expand Down Expand Up @@ -196,6 +202,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
~ThreadWorkerBlock() noexcept(false) {}
};

/*! \brief whether this is a worker thread. */
static MX_THREAD_LOCAL bool is_worker_;
/*! \brief number of concurrent thread cpu worker uses */
int cpu_worker_nthreads_;
/*! \brief number of concurrent thread each gpu worker uses */
Expand All @@ -219,6 +227,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
bool is_copy_worker,
ThreadWorkerBlock<type> *block,
std::shared_ptr<ThreadPool::SimpleEvent> ready_event) {
this->is_worker_ = true;
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *stream;
do {
Expand Down Expand Up @@ -251,6 +260,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
template<dmlc::ConcurrentQueueType type>
inline void CPUWorker(Context ctx,
ThreadWorkerBlock<type> *block) {
this->is_worker_ = true;
auto* task_queue = &(block->task_queue);
RunContext run_ctx{ctx, nullptr};
// execute task
Expand Down Expand Up @@ -303,5 +313,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
/*!
 * \brief Factory for the per-device threaded engine.
 * \return a newly heap-allocated ThreadedEnginePerDevice, exposed through
 *  the Engine base pointer. The object is created with `new` and is not
 *  freed here; presumably the caller takes ownership - TODO confirm.
 */
Engine *CreateThreadedEnginePerDevice() {
  Engine *engine = new ThreadedEnginePerDevice();
  return engine;
}

// Out-of-class definition of the static worker flag declared in the class.
// It starts out false; the GPU and CPU worker routines set it to true
// for the thread they run on.
MX_THREAD_LOCAL bool ThreadedEnginePerDevice::is_worker_ = false;

} // namespace engine
} // namespace mxnet

0 comments on commit 7b40c03

Please sign in to comment.