paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}

-std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
+std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
} // namespace details
} // namespace framework
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
@@ -14,6 +14,9 @@

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
@@ -34,6 +37,10 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {

std::string Name() const override;

  // Delaying nccl_all_reduce ops and buffering them so they run together can
  // significantly improve performance. Returning false disables this feature.
  bool IsMultiDeviceTransfer() override { return true; }

protected:
void RunImpl() override;
};
6 changes: 6 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
@@ -13,6 +13,8 @@
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
@@ -53,6 +55,10 @@ class OpHandleBase {

void AddOutput(VarHandleBase *out);

  // Returns true if the op involves data transfer across multiple devices,
  // which will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }

protected:
virtual void RunImpl() = 0;
};
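The hook above is the entire scheduling contract: an op opts in to delayed execution by overriding one virtual, and the default-false base implementation keeps every existing op on the immediate path. A minimal, self-contained sketch of that pattern (toy class names, not Paddle's; the real OpHandleBase carries much more state):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for OpHandleBase and NCCLAllReduceOpHandle.
struct ToyOpHandle {
  virtual ~ToyOpHandle() = default;
  virtual std::string Name() const = 0;
  // Mirrors OpHandleBase::IsMultiDeviceTransfer(): false by default.
  virtual bool IsMultiDeviceTransfer() { return false; }
};

struct ToyComputeOp : ToyOpHandle {
  std::string Name() const override { return "conv_fwd"; }
};

struct ToyAllReduceOp : ToyOpHandle {
  std::string Name() const override { return "nccl_all_reduce"; }
  bool IsMultiDeviceTransfer() override { return true; }  // opt in to delay
};

int main() {
  std::vector<std::unique_ptr<ToyOpHandle>> ops;
  ops.emplace_back(new ToyComputeOp);
  ops.emplace_back(new ToyAllReduceOp);
  const bool allow_op_delay = true;
  for (auto &op : ops) {
    if (allow_op_delay && op->IsMultiDeviceTransfer()) {
      std::cout << op->Name() << ": buffered for a delayed batch run\n";
    } else {
      std::cout << op->Name() << ": dispatched immediately\n";
    }
  }
  return 0;
}
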
62 changes: 50 additions & 12 deletions paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
-    std::unique_ptr<SSAGraph> &&graph)
+    std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
: SSAGraphExecutor(std::move(graph)),
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
-      use_event_(use_event) {}
+      use_event_(use_event),
+      running_ops_(0),
+      allow_op_delay_(allow_op_delay) {}

void ThreadedSSAGraphExecutor::RunDelayedOps(
const std::unordered_set<OpHandleBase *> &delayed_ops) {
for (auto op : delayed_ops) {
op->Run(use_event_);
}
}

FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;

BlockingQueue<VarHandleBase *> ready_vars;

std::unordered_set<OpHandleBase *> ready_ops;
  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
  // streams from multiple GPUs, it's faster to buffer them and schedule
  // them together, since we currently cannot overlap computation and
  // memcpy streams. Revisit this once overlapping becomes available.
std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;

auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
-      RunOp(ready_vars, op);
+      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+        delayed_ops.insert(op);
+        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
+        ready_vars.Extend(op->outputs_);
+        continue;
+      }
+      running_ops_++;
+      RunOp(&ready_vars, op);
}
ready_ops.clear();
};
@@ -118,13 +139,13 @@
}

// Step 3. Execution
-  while (!pending_vars.empty()) {
+  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
// 1. Run All Ready ops
run_all_ready_ops();

// 2. Find ready variable
bool timeout;
-    auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
+    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

if (timeout) {
if (exception_) {
@@ -141,13 +162,29 @@
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
-        ready_ops.insert(op);
+        if (delayed_vars.find(ready_var) != delayed_vars.end()) {
+          blocked_by_delayed_ops.insert(op);
+        } else {
+          ready_ops.insert(op);
+        }
}
}
}
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : blocked_by_delayed_ops) {
ready_ops.insert(op);
}
blocked_by_delayed_ops.clear();
}
    // Keep looping until all vars are ready.
}

PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
++computation_count_;

auto sync_computation = [&] {
@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}

void ThreadedSSAGraphExecutor::RunOp(
-    BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
-  auto op_run = [&ready_var_q, op, this] {
+    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
+  auto op_run = [ready_var_q, op, this] {
try {
VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_);
-      ready_var_q.Extend(op->outputs_);
+      running_ops_--;
+      ready_var_q->Extend(op->outputs_);
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
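To isolate the control flow introduced above, here is a self-contained, single-threaded simulation of the delay-and-batch policy (made-up op names; no thread pool and no dependency tracking, so blocked_by_delayed_ops and running_ops_ are out of scope): compute ops run immediately, while multi-device transfers are buffered and flushed as one batch once nothing else is runnable.

#include <cstdio>
#include <deque>
#include <queue>
#include <string>
#include <vector>

// Toy op record: transfer ops get delayed, compute ops run immediately.
struct Op {
  std::string name;
  bool is_transfer;
};

int main() {
  const bool allow_op_delay = true;
  // Stand-in for the ready_ops set, already in a valid execution order.
  std::queue<Op> ready_ops(std::deque<Op>{{"conv_bwd_gpu0", false},
                                          {"nccl_all_reduce_w0", true},
                                          {"conv_bwd_gpu1", false},
                                          {"nccl_all_reduce_w1", true}});
  std::vector<Op> delayed_ops;

  while (!ready_ops.empty() || !delayed_ops.empty()) {
    // 1. Drain ready ops, buffering multi-device transfers
    //    (the run_all_ready_ops step above).
    while (!ready_ops.empty()) {
      Op op = ready_ops.front();
      ready_ops.pop();
      if (allow_op_delay && op.is_transfer) {
        delayed_ops.push_back(op);  // buffer instead of running now
      } else {
        std::printf("run now:  %s\n", op.name.c_str());
      }
    }
    // 2. Nothing else is runnable: flush the buffered transfers as one
    //    batch, mirroring RunDelayedOps() once running_ops_ == 0.
    for (const Op &op : delayed_ops) {
      std::printf("delayed:  %s\n", op.name.c_str());
    }
    delayed_ops.clear();
  }
  return 0;
}

The output shows both conv ops running before either all_reduce: that batching of the two transfers is exactly what the delay buys when computation and memcpy streams cannot overlap.
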
16 changes: 13 additions & 3 deletions paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -14,7 +14,12 @@

#pragma once

#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -70,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
-                           std::unique_ptr<SSAGraph> &&graph);
+                           std::unique_ptr<SSAGraph> &&graph,
+                           bool allow_op_delay);

// Run a SSAGraph by a thread pool
// Use topological sort algorithm
@@ -79,16 +85,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {}

private:
-  void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
+  void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
details::OpHandleBase *op);

void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);

private:
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
bool allow_op_delay_;

size_t computation_count_{0};
size_t max_async_computation{100};
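running_ops_ is an std::atomic<int> because it is incremented on the scheduling thread before each dispatch and decremented inside the worker lambda; the run loop reads it to decide when delayed ops may safely be flushed. A minimal sketch of that handshake, using plain std::thread instead of the ThreadPool:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  std::atomic<int> running_ops(0);
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    running_ops++;  // scheduler side: count the op before dispatching it
    workers.emplace_back([&running_ops] {
      // ... op->Run(use_event_) would execute here ...
      running_ops--;  // worker side: no mutex needed for a plain counter
    });
  }
  for (auto &t : workers) t.join();
  // Only when running_ops reaches 0 may delayed ops be flushed.
  std::printf("running_ops = %d\n", running_ops.load());
  return 0;
}

allow_op_delay_ can stay a plain bool since it is set once in the constructor and never written afterwards.
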
8 changes: 5 additions & 3 deletions paddle/fluid/framework/parallel_executor.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"

#include <string>
#include <vector>
@@ -47,7 +48,7 @@ ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
-      const std::string &loss_var_name, Scope *scope)
+      const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;

@@ -82,8 +83,8 @@ ParallelExecutor::ParallelExecutor(
auto graph = builder.Build(main_program);

member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-      num_threads, use_event, member_->local_scopes_, places,
-      std::move(graph)));
+      num_threads, use_event, member_->local_scopes_, places, std::move(graph),
+      allow_op_delay));

// Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) {
@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(

void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
platform::RecordBlock b(0);
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
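platform::RecordBlock b(0) is an RAII profiling scope: construction marks the start of block 0, and destruction records the elapsed time however Run() exits. A generic sketch of the same pattern (a hypothetical ScopedBlockTimer, not Paddle's profiler):

#include <chrono>
#include <cstdio>

// RAII scope timer: times the enclosing block, even on early return or throw.
class ScopedBlockTimer {
 public:
  explicit ScopedBlockTimer(int block_id)
      : block_id_(block_id), start_(std::chrono::steady_clock::now()) {}
  ~ScopedBlockTimer() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_).count();
    std::printf("block %d took %lld us\n", block_id_,
                static_cast<long long>(us));
  }

 private:
  int block_id_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  ScopedBlockTimer b(0);  // analogous to platform::RecordBlock b(0)
  // ... the work being profiled would run here ...
  return 0;
}
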
6 changes: 4 additions & 2 deletions paddle/fluid/framework/parallel_executor.h
@@ -14,8 +14,9 @@ limitations under the License. */

#pragma once

#include <future>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
@@ -37,7 +38,8 @@ class ParallelExecutor {
const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program,
const ProgramDesc& main_program,
-                   const std::string& loss_var_name, Scope* scope);
+                   const std::string& loss_var_name, Scope* scope,
+                   bool allow_op_delay);

void Run(const std::vector<std::string>& fetch_tensors,
const std::string& fetched_var_name = "fetched_var");
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pybind.cc
@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program,
const ProgramDesc &main_program, const std::string &loss_var_name,
-             Scope *scope) {
+             Scope *scope, bool allow_op_delay) {
new (&self) ParallelExecutor(num_threads, use_event, places,
params, startup_program, main_program,
-                                         loss_var_name, scope);
+                                         loss_var_name, scope, allow_op_delay);
})
.def("run", &ParallelExecutor::Run);

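Because the binding constructs ParallelExecutor with the placement-new __init__ idiom, every new constructor parameter must also be threaded through the lambda's signature, which is all this hunk does. A toy pybind11 module showing the same shape (hypothetical ToyExecutor, not Paddle code; building it requires pybind11):

#include <new>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct ToyExecutor {
  ToyExecutor(int num_threads, bool allow_op_delay)
      : num_threads(num_threads), allow_op_delay(allow_op_delay) {}
  int num_threads;
  bool allow_op_delay;
};

PYBIND11_MODULE(toy_executor, m) {
  py::class_<ToyExecutor>(m, "ToyExecutor")
      // Placement-new style, as in pybind.cc above: the lambda signature
      // must mirror the C++ constructor, so adding allow_op_delay to the
      // constructor forces the binding to change too.
      .def("__init__",
           [](ToyExecutor &self, int num_threads, bool allow_op_delay) {
             new (&self) ToyExecutor(num_threads, allow_op_delay);
           });
}
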
16 changes: 13 additions & 3 deletions python/paddle/fluid/parallel_executor.py
@@ -21,7 +21,11 @@


class ParallelExecutor(object):
-    def __init__(self, loss_name, use_cuda, num_threads=None):
+    def __init__(self,
+                 loss_name,
+                 use_cuda,
+                 num_threads=None,
+                 allow_op_delay=False):
places = []
if use_cuda:
for i in xrange(core.get_cuda_device_count()):
@@ -35,7 +39,12 @@ def __init__(self, loss_name, use_cuda, num_threads=None):
places.append(p)

if num_threads is None:
-            num_threads = min(len(places) * 2, multiprocessing.cpu_count())
+            if use_cuda:
+                # Experiments on se-resnext show that too many threads hurt
+                # performance. Worth tuning for other models in the future.
+                num_threads = len(places)
+            else:
+                num_threads = min(len(places) * 2, multiprocessing.cpu_count())

startup = framework.default_startup_program()
main = framework.default_main_program()
@@ -52,7 +61,8 @@ def __init__(self, loss_name, use_cuda, num_threads=None):
startup.desc,
main.desc,
loss_name,
-            scope)
+            scope,
+            allow_op_delay)
self.scope = scope

def run(self, fetch_list):
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -29,6 +29,7 @@ function(py_test_modules TARGET_NAME)
endfunction()

# test time consuming OPs in a separate process to exploit parallelism
list(REMOVE_ITEM TEST_OPS test_parallel_executor)
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dyn_rnn)
list(REMOVE_ITEM TEST_OPS test_mul_op)
@@ -64,6 +65,7 @@
endif(WITH_FAST_BUNDLE_TEST)

# tests with high overhead
py_test_modules(test_parallel_executor MODULES test_parallel_executor)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
py_test_modules(test_train_dyn_rnn MODULES test_dyn_rnn)
py_test_modules(test_mul_op MODULES test_mul_op)