overlap send ops and backward ops #10550
Conversation
class PSDispatcher(object):
    """
    DistributedSpliter is the base class for dispatching vars
DistributedSpliter: this name is not the same as the class name.
Done.
b47a8aa to 0aa6f9e (compare)
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const ProgramDesc &program) const {
  std::vector<std::string> send_vars;
  for (auto *op : program.Block(0).AllOps()) {
Add a comment to say "since parameters are all in block 0, it's enough to only scan send ops in block 0"?
Done.
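Completing the fragment above into a rough, self-contained sketch (not the exact member function from the PR; the ProgramDesc/OpDesc calls are the ones already visible in the diff): the helper only scans block 0 and collects the input argument names of send ops.

// Sketch only: gather the names that send ops read. Since parameters (and
// their split gradient blocks) all live in block 0, scanning block 0 is enough.
std::vector<std::string> FindDistTrainSendVars(const ProgramDesc &program) {
  std::vector<std::string> send_vars;
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "send_vars" || op->Type() == "send") {
      for (auto &name : op->InputArgumentNames()) {
        send_vars.push_back(name);
      }
    }
  }
  return send_vars;
}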
    const ProgramDesc &program) const {
  std::vector<std::string> send_vars;
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "send_vars" || op->Type() == "send") {
We need to gradually remove these "send_vars" "send" strings. It's hard to maintain.
I added a comment here; can we fix this in another PR?
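One way to make the strings easier to maintain, sketched as a suggestion rather than anything this PR does (the helper name IsSendOpType is made up): collect the RPC op type names in a single set.

#include <string>
#include <unordered_set>

// Hypothetical helper: keep the hard-coded send op type names in one place so
// callers stop repeating the string literals.
inline bool IsSendOpType(const std::string &op_type) {
  static const std::unordered_set<std::string> kSendOpTypes{"send", "send_vars"};
  return kSendOpTypes.count(op_type) > 0;
}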
    const ProgramDesc &program) const {
  std::vector<std::string> recv_vars;
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "recv" || op->Type() == "send") {
Why is "send" here as well?
Done, deleted this deprecated send op.
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
This PR needs a convergence experiment to make sure it still converges to the same place.
Sure, will do that.
  return false;
}

  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &sendvars) -> bool {
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      if (var.find(".block") != std::string::npos &&
what's the meaning of '.block'?
The dist transpiler splits the gradient var into multiple blocks and appends a .block suffix to the var name. I will add some comments here to explain the hard-coded string.
        return true;
      }
    }
    return false;
  };

  if (op.Type() == "split" || op.Type() == "split_byref") {
    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
  if (op.Type() == "split" || op.Type() == "split_byref" ||
Is there a way to make these codes more maintainable?
I reviewed this PR again, and we don't need to pay attention to the op type: in the DistributedTranspiler we split the var before send and concat the vars after recv, so paying attention to the Inputs and Outputs is enough.
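Putting the two points above together, the dependency test is roughly the following, shown as a simplified, self-contained sketch rather than the PR's exact code (the example var name with the .block suffix is illustrative): an op is on the distributed-training path if any of its outputs is a split block that gets sent, or any of its inputs is a block that was received.

#include <algorithm>
#include <string>
#include <vector>

// The dist transpiler splits a gradient var such as "fc_0.w_0@GRAD" into
// "fc_0.w_0@GRAD.block0", "fc_0.w_0@GRAD.block1", ..., so a name that contains
// ".block" and appears in the send/recv var lists marks a split/concat op that
// belongs to the distributed-training data path.
inline bool IsDistTrainOpSketch(const std::vector<std::string> &input_names,
                                const std::vector<std::string> &output_names,
                                const std::vector<std::string> &send_vars,
                                const std::vector<std::string> &recv_vars) {
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };
  // Outputs feeding a send op, or inputs produced by a recv op.
  return checker(output_names, send_vars) || checker(input_names, recv_vars);
}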
    const OpDesc &op) const {
  CreateComputationalOp(result, op, 0);
  if (op.Type() == "concat") {
    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
Can there be 2 fetch_barrier ops, and could it connect to the wrong one?
For the current implementation of the dist transpiler, there would be only one fetch_barrier op. I will also add some comments here; maybe adding a bool connect_all argument to the ConnectOp function would be better.
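The connect_all idea could look roughly like the sketch below. The Graph/OpHandle/VarHandleBase structs here are stripped-down stand-ins for the real SSA graph types, invented only to show the shape of the API; this is not the PR's implementation.

#include <memory>
#include <string>
#include <vector>

// Stand-in types, far simpler than the real SSA graph handles.
struct VarHandleBase {};
struct OpHandle {
  std::string type;
  std::vector<std::shared_ptr<VarHandleBase>> inputs, outputs;
};
struct Graph {
  std::vector<std::unique_ptr<OpHandle>> ops;
};

// Make `op` depend on previously created ops of type `prev_op_type` by routing
// a dummy dependency var between them. With connect_all == false only the most
// recent match is connected (enough while the transpiler emits exactly one
// fetch_barrier op); with connect_all == true every match is connected.
void ConnectOp(Graph *graph, OpHandle *op, const std::string &prev_op_type,
               bool connect_all) {
  for (auto it = graph->ops.rbegin(); it != graph->ops.rend(); ++it) {
    if (it->get() == op || (*it)->type != prev_op_type) continue;
    auto dep = std::make_shared<VarHandleBase>();  // dummy dependency var
    (*it)->outputs.push_back(dep);
    op->inputs.push_back(dep);
    if (!connect_all) break;
  }
}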
@@ -257,8 +257,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {

  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);

  channels_[ep] = ch;
  channels_[key] = ch;
Just curious, what's the benefit of having multiple channels per ep?
It's a testing feature for multi-channel throughput; I will revert this code and do more testing in the next PR.
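For reference, the usual one-channel-per-endpoint caching that the code goes back to looks roughly like this; a minimal sketch, not the actual RPCClient (the mutex and the message-size arguments are assumptions):

#include <grpc++/grpc++.h>

#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class ChannelCacheSketch {
 public:
  std::shared_ptr<grpc::Channel> GetChannel(const std::string &ep) {
    std::lock_guard<std::mutex> guard(mu_);
    auto it = channels_.find(ep);
    if (it != channels_.end()) return it->second;  // reuse the cached channel

    grpc::ChannelArguments args;
    args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
    args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
    auto ch =
        grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
    channels_[ep] = ch;  // exactly one channel per endpoint
    return ch;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
};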
bbccdd0 to 268e9dc (compare)
Pass = 0, Training performance = 13.056212 imgs/s, Train accuracy = 0.034640, Test accuracy = 0.025490
Pass = 1, Training performance = 13.458697 imgs/s, Train accuracy = 0.042609, Test accuracy = 0.029412
Pass = 2, Training performance = 13.474373 imgs/s, Train accuracy = 0.049005, Test accuracy = 0.041176
Pass = 3, Training performance = 13.460915 imgs/s, Train accuracy = 0.053057, Test accuracy = 0.048039
Pass = 4, Training performance = 13.443091 imgs/s, Train accuracy = 0.057180, Test accuracy = 0.058824
Pass = 5, Training performance = 13.433230 imgs/s, Train accuracy = 0.059657, Test accuracy = 0.055882
Pass = 6, Training performance = 13.580382 imgs/s, Train accuracy = 0.061427, Test accuracy = 0.045098
Pass = 7, Training performance = 13.464448 imgs/s, Train accuracy = 0.063039, Test accuracy = 0.054902
Pass = 0, Training performance = 14.861740 imgs/s, Train accuracy = 0.030411, Test accuracy = 0.027451
Pass = 1, Training performance = 15.150844 imgs/s, Train accuracy = 0.034640, Test accuracy = 0.046078
Pass = 2, Training performance = 15.204456 imgs/s, Train accuracy = 0.042338, Test accuracy = 0.037255
Pass = 3, Training performance = 15.216520 imgs/s, Train accuracy = 0.048626, Test accuracy = 0.047059
Pass = 4, Training performance = 15.235357 imgs/s, Train accuracy = 0.052724, Test accuracy = 0.049020
Pass = 5, Training performance = 15.267941 imgs/s, Train accuracy = 0.055998, Test accuracy = 0.062745
Pass = 6, Training performance = 15.212211 imgs/s, Train accuracy = 0.059382, Test accuracy = 0.070588
Pass = 7, Training performance = 15.209697 imgs/s, Train accuracy = 0.061555, Test accuracy = 0.067647
Pass = 8, Training performance = 15.232591 imgs/s, Train accuracy = 0.064455, Test accuracy = 0.060784
A problem was found in the process of testing distributed training: the trainer would hang with
Some small comments can be fixed now. Others can be left as TODO.
  return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
return checker(op.OutputArgumentNames(), send_vars) ||
       checker(op.InputArgumentNames(), recv_vars);
return false;
remove this line?
Done.
                   "rpc op should be in [send,"
                   "send_vars, send_barrier. recv, fetch_barrier]");
}

// FIXME(wuyi): send op always copy from GPU 0
remove this FIXME as well?
Done, and added a TODO to improve the performance.
// Wait input done
for (auto *in : inputs_) {
  auto &p = static_cast<VarHandle *>(in)->place_;
  if (in->DebugString() == "dummy") {  // HACK
need a better solution here soon.
Done, I added a TODO comment here.
paddle/fluid/framework/variable.h
Outdated
@@ -38,6 +39,7 @@ class Variable {

  template <typename T>
  T* GetMutable() {
    std::unique_lock<std::mutex> lock(mutex_);
Add a TODO here to make Variable completely thread-safe?
Done.
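For context, the change under discussion guards GetMutable with a per-variable mutex, roughly as in this trimmed-down, self-contained sketch (not the real Variable class):

#include <memory>
#include <mutex>
#include <typeindex>

class VariableSketch {
 public:
  template <typename T>
  T *GetMutable() {
    // TODO: this only serializes concurrent GetMutable calls; it does not make
    // the held value itself thread-safe.
    std::unique_lock<std::mutex> lock(mutex_);
    // Lazily create the held value the first time a type is requested.
    if (holder_ == nullptr || type_ != std::type_index(typeid(T))) {
      holder_ = std::make_shared<T>();
      type_ = std::type_index(typeid(T));
    }
    return static_cast<T *>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
  std::type_index type_{typeid(void)};
  std::mutex mutex_;
};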
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#pragma once
redundant?
Done.
@@ -20,6 +20,9 @@ namespace operators {

inline bool NeedSend(const framework::Scope& scope,
Should this method test if the variable is a parameter on ps server?
Maybe not; an op doesn't need to care about where the parameter is (the transpiler knows much more about this), it only needs to consider whether the variable is initialized or not.
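In other words, the helper only has to ask whether the variable exists and already holds data; roughly like the sketch below, written against the fluid Scope/Variable API but not guaranteed to match the real NeedSend line for line:

#include <string>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"

// Sketch: an op only needs to know whether the variable has data to send;
// where the parameter lives on the pserver side is the transpiler's concern.
inline bool NeedSendSketch(const paddle::framework::Scope &scope,
                           const std::string &varname) {
  auto *var = scope.FindVar(varname);
  if (var == nullptr) return false;
  if (var->IsType<paddle::framework::LoDTensor>()) {
    return var->Get<paddle::framework::LoDTensor>().IsInitialized();
  } else if (var->IsType<paddle::framework::SelectedRows>()) {
    return var->Get<paddle::framework::SelectedRows>().rows().size() > 0UL;
  }
  return false;
}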
return [op.type for op in trainer.global_block().ops
        ] + ["split_byref", "send", "concat"]
ops = [op.type for op in trainer.global_block().ops] + [
    "split_byref", "send_vars", "send_barrier", "recv", "recv",
recv duplicated?
class HashName(PSDispatcher):
    """
    Hash variable names to servral endpoints
several
Done.
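The idea behind this dispatcher is just a stable name-to-endpoint mapping. A few illustrative lines of C++ showing the same scheme (the real HashName class lives in the Python transpiler; this is not its implementation):

#include <functional>
#include <string>
#include <vector>

// Hash the variable name and pick an endpoint by modulo, so the same name
// always lands on the same parameter server.
inline const std::string &DispatchByHash(
    const std::vector<std::string> &endpoints, const std::string &var_name) {
  size_t idx = std::hash<std::string>{}(var_name) % endpoints.size();
  return endpoints[idx];
}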
eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1:
    orig_varname = splited_vars[0].name
    index = find_op_by_output_arg(program.global_block(),
Does dist_transpiler work well with memory_optimize_transpiler? memory_optimize_transpiler tends to change var names. Leave a TODO?
Thanks for the reminder, we have not tested whether distributed_transpiler works well with memory_optimize_transpiler; it's a good point.
#     type="fetch_barrier",
#     inputs={},
#     outputs={"RPCClient": rpc_client_var},
#     attrs={"endpoints": pserver_endpoints})
clean up?
Done.
for (auto *op : program.Block(0).AllOps()) {
  // TODO(Yancey1989): use a graceful method to find send op,
  // instead of the the hard code string
  if (op->Type() == "send_vars") {
Can we merge the send_vars and send ops by adding some attributes to control the behaviour?
Please follow the comment: #10550 (comment)
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
std::ofstream fout("/tmp/graph.dot");
Use GFLAG to define this output path.
Done.
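The gflags version is straightforward; a sketch with a made-up flag name, not the flag this PR actually adds:

#include <fstream>
#include <string>

#include "gflags/gflags.h"

// Hypothetical flag name: define the dump path once instead of hard-coding
// "/tmp/graph.dot" at the call site.
DEFINE_string(ssa_graph_dot_path, "/tmp/graph.dot",
              "Path to dump the SSA graph in Graphviz dot format.");

void DumpGraphviz(const std::string &dot_text) {
  std::ofstream fout(FLAGS_ssa_graph_dot_path);
  fout << dot_text;
}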
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/rpc_op_handle.h"
If send_op_handle is not used anymore, can you just rename it?
Deleted the unused send_op_handle.
paddle/fluid/framework/variable.h
Outdated
@@ -38,6 +39,7 @@ class Variable {

  template <typename T>
  T* GetMutable() {
    std::unique_lock<std::mutex> lock(mutex_);
I don't think we should give the var a mutex here. Vars are not thread-safe in any programming language; you must protect your vars where you use them concurrently, not add a lock here. Also, adding a mutex cannot actually protect the var's content.
This cannot make the var completely thread-safe, it only makes the GetMutable function thread-safe. I filed an issue to follow up on making Variable and the gRPC client thread-safe. Related issue: #10969
I think we should use a workaround instead of changing the Variable implementation anyway.
Maybe we can use the GRPCClient instance as a singleton, or do you have some other ideas?
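A Meyers-style singleton is the simplest form of that idea; sketched below with a placeholder client type, not the actual GRPCClient interface:

// Placeholder for the real RPC client type; only the accessor pattern matters.
class RpcClientSketch {
 public:
  static RpcClientSketch &GetInstance() {
    // Function-local static: initialized once and thread-safe since C++11, so
    // every caller shares a single client (and its channels and locks).
    static RpcClientSketch instance;
    return instance;
  }

  RpcClientSketch(const RpcClientSketch &) = delete;
  RpcClientSketch &operator=(const RpcClientSketch &) = delete;

 private:
  RpcClientSketch() = default;
};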
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#pragma once
Why #pragma once twice?
Redundant code, deleted it.
#include <future>  // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
I think we might only need the send and recv ops, adding attrs to control their behavior, so that we don't copy this code around.
Good idea, but it's hard with the current implementation of SSAGraph, which uses op types and input/output arguments to determine dependencies; we need to do more to merge send_op, send_vars_op, batch_barrier, recv_op and fetch_barrier. I created an issue to follow up on this comment. Related issue: #10968
Please do not merge this PR until it is cleaned up. Thanks!
Hi @typhoonzero, fixed the comments above, and made
LGTM++
@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

  bool is_forwarding = true;
  for (auto *op : program.Block(0).AllOps()) {
    if (op->Type() == "send") {
    // append send op if program is distributed trainer main program.
    if (boost::get<int>(
cool.
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
Should add back this FIXME?
Will add back this comment in the next PR.
Fixed #9161
Fixed #10969
Experiment with vgg16 + flowers on P40, 2 pservers + 2 trainers
The performance improves by 12% on a single device and by 20% on multiple devices.