diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index c01de753..28bc92f7 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -55,7 +55,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine { #ifndef _WIN32 pthread_atfork( []() { - Engine::Get()->WaitForAll(); Engine::Get()->Stop(); }, []() { @@ -71,10 +70,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { #endif } ~ThreadedEnginePerDevice() noexcept(false) { - this->Stop(); + this->StopNoWait(); } - void Stop() override { + void StopNoWait() { SignalQueuesForKill(); gpu_normal_workers_.Clear(); gpu_copy_workers_.Clear(); @@ -82,7 +81,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { cpu_priority_worker_.reset(nullptr); } + void Stop() override { + if (is_worker_) return; + WaitForAll(); + StopNoWait(); + } + void Start() override { + if (is_worker_) return; gpu_worker_nthreads_ = common::GetNumThreadPerGPU(); cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1); // create CPU task @@ -196,6 +202,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { ~ThreadWorkerBlock() noexcept(false) {} }; + /*! \brief whether this is a worker thread. */ + static MX_THREAD_LOCAL bool is_worker_; /*! \brief number of concurrent thread cpu worker uses */ int cpu_worker_nthreads_; /*! \brief number of concurrent thread each gpu worker uses */ @@ -219,6 +227,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { bool is_copy_worker, ThreadWorkerBlock *block, std::shared_ptr ready_event) { + this->is_worker_ = true; #if MXNET_USE_CUDA mshadow::Stream *stream; do { @@ -251,6 +260,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { template inline void CPUWorker(Context ctx, ThreadWorkerBlock *block) { + this->is_worker_ = true; auto* task_queue = &(block->task_queue); RunContext run_ctx{ctx, nullptr}; // execute task @@ -303,5 +313,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { Engine *CreateThreadedEnginePerDevice() { return new ThreadedEnginePerDevice(); } + +MX_THREAD_LOCAL bool ThreadedEnginePerDevice::is_worker_ = false; + } // namespace engine } // namespace mxnet