Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h

cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy collective_helper
graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper)

cc_library(executor_cache SRCS executor_cache.cc DEPS executor)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library(bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context)
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
Expand Down
316 changes: 316 additions & 0 deletions paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
namespace details {

static std::atomic<unsigned int> exec_op_count_;
static std::atomic<int> error_state;

BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: strategy_(strategy),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places),
graph_(graph),
prepare_pool_(1),
multi_device_op_pool_(1) {
for (uint32_t i = 0; i < places.size(); i++) {
pool_.emplace_back(std::unique_ptr<::ThreadPool>(new ::ThreadPool(1)));
}
Comment on lines +47 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

能否解释一下,每个线程池里1根线程,N个线程池,是出于什么考虑?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一个线程绑定一个XPU设备,因为目前xpu的device_context不是线程安全的。

int index = 0;
for (uint32_t i = 0; i < places.size(); i++) {
int id = BOOST_GET_CONST(platform::XPUPlace, places_[i]).device;
if (place_to_index_.find(id) == place_to_index_.end()) {
place_to_index_[id] = index;
index++;
}
}
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
if (dep == 0) {
bootstrap_ops_.emplace_back(op);
}
}
PADDLE_ENFORCE_GT(op_deps_.size(), 0,
platform::errors::PreconditionNotMet(
"The graph doesn't have operators."));
PrepareAtomicOpDeps();
}

static std::vector<OpHandleBase *> get_children(OpHandleBase *op) {
auto &outputs = op->Outputs();
std::vector<OpHandleBase *> ret;
for (auto &output : outputs) {
ret.insert(ret.end(), output->PendingOps().begin(),
output->PendingOps().end());
}
return ret;
}

static std::vector<OpHandleBase *> get_parents(OpHandleBase *op) {
auto &inputs = op->Inputs();
std::vector<OpHandleBase *> ret;
for (auto &input : inputs) {
if (input->GeneratedOp() != nullptr) {
ret.push_back(input->GeneratedOp());
}
}
return ret;
}

FetchResultType BindThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter BindThreadedSSAGraphExecutor Run";
return RunMainStream(fetch_tensors, return_merged);
}

// use 2 streams to run op. The first stream is main stream and will run
// most op exclude op depending on multi device(e.g., all_reduce, fetch op)
FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter MainStream Run";
std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();

error_state = 0;
paddle::framework::FetchResultType fetches;
if (return_merged) {
fetches = FetchList(fetch_tensors.size());
} else {
fetches = FetchUnmergedList(fetch_tensors.size());
}
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
auto ready_ops = std::make_shared<BlockingQueue<OpHandleBase *>>();
exception_.Clear();

InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops, return_merged);
for (auto cur_op : bootstrap_ops_) {
ready_ops->Push(cur_op);
}
for (auto cur_op : ready_fetch_ops) {
ready_ops->Push(cur_op);
}

exec_op_count_ = 0;

platform::XPUPlace cur_place;
std::size_t cur_count = 0;

while (cur_count < op_deps_.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你的cur_count代表的是执行ready_ops力的op的个数。而ready_ops里是有ready_fetch_ops op的。假如op_deps_的op个数是100个。而ready_fetch_ops的个数是100个。瞬间就能达到跳出while的条件,但是真正的op_deps_里的op并没有执行完。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ready_ops除了有ready_fetch_ops还有bootstrap_ops_,达不到跳出while的条件。

cur_count++;
auto cur_op = ready_ops->Pop();
if (cur_op == nullptr) {
// sleep a while to make sure worker thread quit
sleep(10);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里要等10秒??

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

赶2.0发版,待此PR合入后再行修改

exec_op_count_ = op_deps_.size();
break;
}
auto dev_ctxes_ = cur_op->DeviceContext();
if (cur_op->IsMultiDeviceTransfer()) {
RunMultiDeviceOpAsync(cur_op, op_deps.get(), ready_ops);
continue;
} else {
cur_place =
BOOST_GET_CONST(platform::XPUPlace, dev_ctxes_.begin()->first);
int cur_index = place_to_index_[cur_place.device];
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
}
}
while (exec_op_count_ < op_deps_.size()) {
}
Comment on lines +154 to +155
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种写法,CPU直接100%了。还有上面的sleep问题。统一考虑加一个wait机制吧。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

赶2.0发版,待此PR合入后再行修改


// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
return fetches;
}

void BindThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
fetch_tensors.end());
for (auto &fetch_var_name : fetch_tensor_set) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
}
}
}

for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors.at(i);
auto fetched_var_it = fetched_vars->find(var_name);
PADDLE_ENFORCE_NE(
fetched_var_it, fetched_vars->end(),
platform::errors::PreconditionNotMet(
"Cannot find fetched variable(%s) in current computation graph. "
"Possible reasons are:\n"
" 1. The variable to be fetched is not defined in main program.\n"
" 2. The variable to be fetched is not an input or output of any "
"operator.\n"
" 3. Confirm that you have used the fetch `Variable` format "
"instead of the string literal('%s') in `fetch_list` parameter "
"when using `executor.run` method. In other words, the format of "
"`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) "
"is recommended.",
var_name, var_name));

auto &vars = fetched_var_it->second;

ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
fetch_ops->emplace_back(op);

platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
for (auto &p : places_) {
op->SetDeviceContext(p, pool.Get(p));
}

for (auto *var : vars) {
op->AddInput(var);
}

int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op].dep_num = dep;
(*op_deps)[op].op = op;
if (dep == 0) {
ready_fetch_ops->emplace_back(op);
}
}
}

void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops) {
multi_device_op_pool_.enqueue([=] {
try {
if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
auto dev_ctxes = op->DeviceContext();
auto &inputs = op->Inputs();
for (auto &input : inputs) {
auto dev_ctxes = input->GeneratedOp()->DeviceContext();
for (auto &item : dev_ctxes) {
((platform::XPUDeviceContext *)(item.second))->Wait();
}
}
op->Run(strategy_.use_device_);
auto &outputs = op->Outputs();
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
if (deps.fetch_sub(1) == 1) {
ready_ops->Push(pending_op);
}
}
}
} else if (error_state) {
ready_ops->Push(nullptr);
}
} catch (...) {
error_state = 1;
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
exec_op_count_++;
});
}

void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index) {
pool_[index]->enqueue([=] {
try {
if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_device_);
auto &outputs = op->Outputs();
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
if (deps.fetch_sub(1) == 1) {
ready_ops->Push(pending_op);
}
}
}
} else if (error_state) {
ready_ops->Push(nullptr);
}
} catch (...) {
error_state = 1;
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
exec_op_count_++;
});
}

void BindThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, struct RunningItem>;
for (auto &pair : op_deps_) {
(*op_deps)[pair.first].dep_num = pair.second;
(*op_deps)[pair.first].op = pair.first;
}
return std::unique_ptr<
std::unordered_map<OpHandleBase *, struct RunningItem>>(op_deps);
});
}

const ir::Graph &BindThreadedSSAGraphExecutor::Graph() const { return *graph_; }

void BindThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}

} // namespace details
} // namespace framework
} // namespace paddle
#endif
Loading