overlap send ops and backward ops #10550
Changes from all commits
b1e5183
6e5635f
b35ea1a
315e44a
00efc4c
eb2e68e
274df85
62af10d
952fa04
e9abc66
147d54b
6debbcd
540b453
fc06222
0aa6f9e
ceefbf3
268e9dc
ad6c014
28596a3
20c24c0
60d827a
6b91d40
8b630ae
5d7c58e
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
@@ -28,6 +29,10 @@
#include <string>
#include <vector>

DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
              "path of the SSA graph .dot file, only written when GLOG_v=10; "
              "default: /tmp/ssa_graph.dot");

namespace paddle {
namespace framework {
namespace details {
@@ -79,32 +84,66 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
  }
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
                                            OpDesc *send_op) const {
  if (send_op == nullptr) {
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const ProgramDesc &program) const {
  std::vector<std::string> send_vars;
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
  for (auto *op : program.Block(0).AllOps()) {
    // TODO(Yancey1989): use a graceful method to find send op,
    // instead of the hard-coded string
    if (op->Type() == "send_vars") {
Review comment: Can we merge
Reply: Please follow the comment: #10550 (comment)
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const ProgramDesc &program) const {
  std::vector<std::string> recv_vars;
  for (auto *op : program.Block(0).AllOps()) {
    // TODO(Yancey1989): use a graceful method to find recv op,
    // instead of the hard-coded string
    if (op->Type() == "recv") {
      auto op_vars = op->OutputArgumentNames();
      recv_vars.reserve(recv_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    const OpDesc &op, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check whether any of the op's variable names carries the `.block` suffix
   * and also appears in the given send/recv (rpc) variable list.
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &sendvars) -> bool {
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // a variable name with the suffix `.block` means it's a variable
      // split by the DistributeTranspiler
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  if (op.Type() == "split" || op.Type() == "split_byref") {
    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
  } else if (op.Type() == "concat") {
    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
  }
  return false;
  return checker(op.OutputArgumentNames(), send_vars) ||
         checker(op.InputArgumentNames(), recv_vars);
}

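A minimal, self-contained sketch of the matching rule the new IsDistTrainOp relies on (plain C++ with made-up variable names, not Paddle code): an op counts as a distributed-training op when one of its argument names carries the ".block" suffix produced by the DistributeTranspiler and also appears in the send or recv variable list.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Stand-alone version of the checker lambda above: true if any name in
// `opvars` has the ".block" suffix and is also listed in `rpc_vars`.
static bool Matches(const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names: a gradient "w@GRAD" split into two slices by the
  // transpiler; the slices are exactly the inputs of the send_vars op.
  std::vector<std::string> send_vars = {"w@GRAD.block0", "w@GRAD.block1"};
  std::vector<std::string> split_outputs = {"w@GRAD.block0", "w@GRAD.block1"};
  std::vector<std::string> plain_outputs = {"fc_0.tmp_1"};
  std::cout << Matches(split_outputs, send_vars) << "\n";  // 1 -> dist-train op
  std::cout << Matches(plain_outputs, send_vars) << "\n";  // 0 -> ordinary op
  return 0;
}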
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
      places_.size());

  // Find "send" op first for split is in front of send.
  OpDesc *send_op = GetSendOpDesc(program);
  // find send/recv vars so that we can place the distributed training
  // related ops on place 0
  auto send_vars = FindDistTrainSendVars(program);
  auto recv_vars = FindDistTrainRecvVars(program);

  size_t cur_device_id = 0;
  std::vector<std::unordered_set<std::string>> var_name_on_devices;

@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

  bool is_forwarding = true;
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "send") {
      // append send op if program is distributed trainer main program.
    if (boost::get<int>(
Review comment: cool.
            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
      // append rpc op if program is distributed trainer main program.
      // always use the first device
      CreateSendOp(&result, *op);
    } else if (IsDistTrainOp(*op, send_op)) {
      CreateComputationalOps(&result, *op, 1);
      CreateRPCOp(&result, *op);
    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
      CreateDistTrainOp(&result, *op);
    } else if (IsScaleLossOp(*op)) {
      // user can customize loss@grad if not use_default_grad_scale_
      if (strategy_.gradient_scale_ !=

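For readers unfamiliar with the attribute machinery used in the new condition: instead of matching on the op type string, the builder now reads the op-role attribute and compares it with OpRole::kRPC. Below is a rough, self-contained analogue of that check, with std::variant standing in for Paddle's boost::variant-based attribute storage, and with a hypothetical attribute key and hypothetical role values.

#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>

// Stand-in for an op's attribute map; Paddle stores attributes in a
// boost::variant and extracts them with boost::get<int>.
using Attribute = std::variant<int, float, std::string>;

// Hypothetical role values; the real OpRole enum ships with
// OpProtoAndCheckerMaker and may use different numbers.
enum class OpRole : int { kForward = 0, kBackward = 1, kRPC = 2 };

// Analogue of: boost::get<int>(op->GetAttr(OpRoleAttrName())) == kRPC
bool IsRpcOp(const std::unordered_map<std::string, Attribute> &attrs) {
  auto it = attrs.find("op_role");  // assumed attribute key
  return it != attrs.end() &&
         std::get<int>(it->second) == static_cast<int>(OpRole::kRPC);
}

int main() {
  std::unordered_map<std::string, Attribute> send_op_attrs{
      {"op_role", static_cast<int>(OpRole::kRPC)}};
  std::cout << IsRpcOp(send_op_attrs) << "\n";  // prints 1
  return 0;
}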
@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
  AddOutputToLeafOps(&result);

  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    PrintGraphviz(*graph, sout);
    VLOG(10) << sout.str();
    std::ofstream fout(FLAGS_ssa_graph_path);
    PrintGraphviz(*graph, fout);
  }

  return std::unique_ptr<SSAGraph>(graph);

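With this change the Graphviz dump no longer goes through VLOG: when verbose logging is enabled at level 10 (for example GLOG_v=10), the builder writes the graph to the file named by the ssa_graph_path flag, /tmp/ssa_graph.dot by default. Assuming Graphviz is installed, the dump can then be rendered with something like dot -Tpng /tmp/ssa_graph.dot -o ssa_graph.png.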
@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
  CreateOpHandleIOs(result, op, dev_id);
}

OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
    const ProgramDesc &program) const {
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "send") {
      return op;
    }
  }
  return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
    SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA

@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
  return var;
}

void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
                                           const OpDesc &op) const {
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
                                        const std::string &prev_op_name) const {
  for (auto &prev_op : result->ops_) {
    if (prev_op->Name() == prev_op_name) {
      auto *dep_var = new DummyVarHandle();
      prev_op->AddOutput(dep_var);
      result->dep_vars_.emplace(dep_var);
      op->AddInput(dep_var);
    }
  }
}

void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
                                                const OpDesc &op) const {
  CreateComputationalOp(result, op, 0);
  if (op.Type() == "concat") {
    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
Review comment: Can there be 2 fetch_barrier ops? and it connects to the wrong one?
Reply: For the current implementation of
  }
}

void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                          const OpDesc &op) const {
  auto &p = places_[0];
  auto *s = local_scopes_[0];
  // FIXME(wuyi): send op always copy from GPU 0
Review comment: Should add back this FIXME?
Reply: Will add back this comment in the next PR.
  result->ops_.emplace_back(new SendOpHandle(op, s, p));
  // Create inputs for output on original place and no ssa output
  // is created for send op.
  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));

  if (op.Type() == "send_barrier") {
    ConnectOp(result, result->ops_.back().get(), "send_vars");
  } else if (op.Type() == "recv") {
    ConnectOp(result, result->ops_.back().get(), "send_barrier");
  } else if (op.Type() == "fetch_barrier") {
    ConnectOp(result, result->ops_.back().get(), "recv");
  } else if (op.Type() == "send_vars") {
    // do nothing
  } else {
    PADDLE_THROW(
        "rpc op should be in ["
        "send_vars, send_barrier, recv, fetch_barrier]");
  }

  // TODO(Yancey1989): schedule rpc op on different place may
  // increase throughput
  CreateOpHandleIOs(result, op, 0);
}
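Taken together, the ConnectOp calls above chain the RPC ops with dummy control-dependency variables: send_barrier waits for send_vars, recv waits for send_barrier, fetch_barrier waits for recv, and the concat created in CreateDistTrainOp waits for fetch_barrier. The following stand-alone sketch mimics that mechanism with simplified stand-in types (not the actual SSAGraph/OpHandleBase/DummyVarHandle classes):

#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for the op/variable handles in the SSA graph.
struct Var {};  // dummy variable used only to express ordering
struct Op {
  std::string name;
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
};

struct Graph {
  std::vector<std::unique_ptr<Op>> ops;
  std::vector<std::unique_ptr<Var>> dep_vars;
};

// Mirrors ConnectOp: make `op` depend on every previously created op whose
// name equals `prev_op_name` by routing a fresh dummy variable between them.
void Connect(Graph *g, Op *op, const std::string &prev_op_name) {
  for (auto &prev_op : g->ops) {
    if (prev_op->name == prev_op_name) {
      g->dep_vars.emplace_back(new Var());
      Var *dep = g->dep_vars.back().get();
      prev_op->outputs.push_back(dep);  // prev_op produces the dummy var
      op->inputs.push_back(dep);        // op consumes it, so it runs later
    }
  }
}

int main() {
  Graph g;
  for (const char *name :
       {"send_vars", "send_barrier", "recv", "fetch_barrier", "concat"}) {
    g.ops.emplace_back(new Op{name, {}, {}});
  }
  // Rebuild the dependency chain the builder sets up for the RPC ops.
  Connect(&g, g.ops[1].get(), "send_vars");      // send_barrier after send_vars
  Connect(&g, g.ops[2].get(), "send_barrier");   // recv after send_barrier
  Connect(&g, g.ops[3].get(), "recv");           // fetch_barrier after recv
  Connect(&g, g.ops[4].get(), "fetch_barrier");  // concat after fetch_barrier
  return 0;
}

Since only these explicit edges order the RPC ops, each send_vars handle can start as soon as its input gradient is produced, which is the overlap between send ops and backward ops that the PR title describes.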
Review comment: This PR needs a convergence experiment to make sure it still converges to the same place.
Reply: Sure, will do that.