overlap send ops and backward ops #10550

Merged: 24 commits, May 29, 2018

Commits
b1e5183
overlap sendop and backward ops
Yancey0623 May 10, 2018
6e5635f
update
Yancey0623 May 10, 2018
b35ea1a
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 14, 2018
315e44a
add fetch_barrier_op
Yancey0623 May 15, 2018
00efc4c
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 15, 2018
eb2e68e
revert executor run
Yancey0623 May 15, 2018
274df85
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 17, 2018
62af10d
support multiple devices
Yancey0623 May 21, 2018
952fa04
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 21, 2018
e9abc66
fix pe
Yancey0623 May 22, 2018
147d54b
update
Yancey0623 May 22, 2018
6debbcd
connect fetch barrier and concat op
Yancey0623 May 23, 2018
540b453
use req_count as atomic type
Yancey0623 May 23, 2018
fc06222
fix async worker
Yancey0623 May 24, 2018
0aa6f9e
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 25, 2018
ceefbf3
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 27, 2018
268e9dc
polish code
Yancey0623 May 27, 2018
ad6c014
clean up codes
Yancey0623 May 28, 2018
28596a3
add gflag ssa_graph_path
Yancey0623 May 28, 2018
20c24c0
singleton rpc_client
Yancey0623 May 29, 2018
60d827a
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap…
Yancey0623 May 29, 2018
6b91d40
revert variable mutex
Yancey0623 May 29, 2018
8b630ae
fix unit test
Yancey0623 May 29, 2018
5d7c58e
fix code style
Yancey0623 May 29, 2018
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
Contributor: This PR needs a convergence experiment to make sure it still converges to the same place.

Contributor Author: Sure, will do that.

cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)

cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +26,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)

cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)

cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
141 changes: 104 additions & 37 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

@@ -28,6 +29,10 @@
#include <string>
#include <vector>

DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the SSA graph path, only printed with GLOG_v=10; "
"default: /tmp/ssa_graph.dot");

namespace paddle {
namespace framework {
namespace details {
@@ -79,32 +84,66 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
}
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
OpDesc *send_op) const {
if (send_op == nullptr) {
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the hard-coded string
if (op->Type() == "send_vars") {
Contributor: Can we merge send_vars and send op by adding some attributes to control the behaviour?

Contributor Author: Please follow the comment: #10550 (comment)

auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
}
}
return send_vars;
}

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const {
std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard-coded string
if (op->Type() == "recv") {
auto op_vars = op->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
}
}
return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}

/**
* Check whether any of opvars contains `.block` and appears in rpc_vars.
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &sendvars) -> bool {
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a variable
// split by the DistributeTranspiler
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};

if (op.Type() == "split" || op.Type() == "split_byref") {
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
}
return false;
return checker(op.OutputArgumentNames(), send_vars) ||
checker(op.InputArgumentNames(), recv_vars);
}
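For illustration, a standalone sketch of the suffix-plus-membership test above; the variable names and the MatchesRpcVars helper are invented for this example, not part of the PR:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the checker lambda: true when any op var carries the `.block`
// suffix produced by the DistributeTranspiler and also appears in rpc_vars.
static bool MatchesRpcVars(const std::vector<std::string>& opvars,
                           const std::vector<std::string>& rpc_vars) {
  for (const auto& var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // A split op whose outputs feed the send vars counts as a dist-train op.
  std::vector<std::string> split_outputs = {"w.block0", "w.block1"};
  std::vector<std::string> send_vars = {"w.block0", "w.block1"};
  std::cout << MatchesRpcVars(split_outputs, send_vars) << "\n";  // prints 1
  return 0;
}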

std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());

// Find "send" op first for split is in front of send.
OpDesc *send_op = GetSendOpDesc(program);
// find send/recv vars so that we can place the distributed training
// related ops on place 0
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);

size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
// append send op if program is distributed trainer main program.
if (boost::get<int>(

Contributor: cool.

op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateSendOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result);

if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, fout);
}

return std::unique_ptr<SSAGraph>(graph);
@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id);
}

OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var;
}

void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
const OpDesc &op) const {
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
op->AddInput(dep_var);
}
}
}

void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
Contributor: Can there be 2 fetch_barrier ops, and it connects to the wrong one?

Contributor Author: For the current implementation of the dist transpiler there would be only one fetch_barrier_op; I will also add some comments here, and maybe adding an argument bool connect_all to ConnectOp would be better.

}
}
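A minimal sketch of the bool connect_all variant floated in the reply above; the parameter is hypothetical and not part of this PR, with the early exit encoding the single-fetch_barrier assumption:

// Hypothetical variant: connect `op` after every op named `prev_op_name`,
// or only after the first match when connect_all is false.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
                                        const std::string &prev_op_name,
                                        bool connect_all) const {
  for (auto &prev_op : result->ops_) {
    if (prev_op->Name() == prev_op_name) {
      auto *dep_var = new DummyVarHandle();
      prev_op->AddOutput(dep_var);
      result->dep_vars_.emplace(dep_var);
      op->AddInput(dep_var);
      if (!connect_all) break;  // assumes a single fetch_barrier today
    }
  }
}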

void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copies from GPU 0
Contributor: Should we add back this FIXME?

Contributor Author: Will add back this comment in the next PR.

result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for the output on the original place; no SSA output
// is created for the send op.
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));

if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
}

// TODO(Yancey1989): scheduling rpc ops on different places may
// increase throughput
CreateOpHandleIOs(result, op, 0);
}
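As an aside, the role check that routes ops to CreateRPCOp in Build() could be factored into a helper; a sketch assuming op roles are stored as plain int attributes, as the boost::get<int> call above implies (IsRPCOp is a hypothetical name, not part of this PR):

// Hypothetical helper: true when the transpiler tagged `op` with
// OpRole::kRPC via the op role attribute.
static bool IsRPCOp(const OpDesc &op) {
  return boost::get<int>(
             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC);
}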

22 changes: 14 additions & 8 deletions paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

bool IsScaleLossOp(const OpDesc &op) const;

void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;

/**
* Is this operator an end-point operator before/after the send operator?
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;

std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;

std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;

void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const;

void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const;
@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;

/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;

bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const;
paddle/fluid/framework/details/send_op_handle.cc → paddle/fluid/framework/details/rpc_op_handle.cc
@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"

namespace paddle {
namespace framework {
namespace details {

SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place,
const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
place_(place),
name_(name) {}

void SendOpHandle::RunImpl() {
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis on whether to wait on DummyVarHandle.
// Wait until all inputs are done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of using DebugString()
if (in->DebugString() == "dummy") { // HACK
continue;
}
@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
op_->Run(*tmp_scope, place_);
}
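One possible direction for the FIXME above (a hedged sketch, not code from this PR): detect dummy dependency vars by type instead of comparing DebugString():

// Hypothetical replacement for the DebugString() == "dummy" hack:
// dependency-only edges are DummyVarHandle instances, so RTTI can
// detect them, assuming the var handle base class is polymorphic.
if (dynamic_cast<DummyVarHandle *>(in) != nullptr) {
  continue;  // nothing to wait on for a pure dependency edge
}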

std::string SendOpHandle::Name() const { return "send"; }
std::string RPCOpHandle::Name() const { return name_; }
} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/send_op_handle.h → paddle/fluid/framework/details/rpc_op_handle.h
@@ -27,9 +27,9 @@ namespace paddle {
namespace framework {
namespace details {

struct SendOpHandle : public OpHandleBase {
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name);

std::string Name() const override;

Expand All @@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
const std::string name_;
};

} // namespace details
2 changes: 1 addition & 1 deletion paddle/fluid/framework/op_proto_maker.cc
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum(
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward),
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_proto_maker.h
@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
kRPC = 0x0003,

kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
4 changes: 2 additions & 2 deletions paddle/fluid/inference/analysis/data_flow_graph_tester.cc
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {

GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
4 changes: 3 additions & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
Expand All @@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()

op_library(cross_entropy_op DEPS cross_entropy)