Skip to content

Batch barrier in send/recv op #7847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 29, 2018
15 changes: 15 additions & 0 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);

BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;

return true;
}

bool RPCClient::Wait() {
if (req_count_ <= 0) {
return true;
Expand Down
24 changes: 24 additions & 0 deletions paddle/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline);
}

virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());

std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

context_->set_deadline(deadline);
}

virtual void Process() = 0;

std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Expand Down Expand Up @@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};

class BatchBarrierProcessor : public ClientBase {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}

virtual ~BatchBarrierProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
};

class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
Expand All @@ -130,6 +150,10 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);

bool AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);

bool Wait();

private:
Expand Down
13 changes: 7 additions & 6 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() {

cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();

server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;

Expand All @@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);

t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));

t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));

// wait server
Expand Down Expand Up @@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status();
VLOG(4) << "Create RequestSend status:" << send->Status();
}

void AsyncGRPCServer::TryToRegisterNewGetOne() {
Expand All @@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
VLOG(4) << "Create RequestGet status:" << get->Status();
}

// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
Expand Down
3 changes: 1 addition & 2 deletions paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();

protected:
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
Expand Down
3 changes: 3 additions & 0 deletions paddle/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace paddle {
namespace operators {
namespace detail {

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"

void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
Expand Down
55 changes: 32 additions & 23 deletions paddle/operators/recv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ limitations under the License. */
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -95,46 +93,57 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();

auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
framework::Executor executor(dev_place);

// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
size_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
size_t recv_var_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
PADDLE_THROW("Can not find server side var");
// receive a variable
recv_var_cnt++;
auto it =
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;

if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) {
break;
}
Expand All @@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size);
rpc_service_->WaitClientGet(recv_var_cnt);
grads_counter_.clear();
} // while(true)
}
Expand Down
12 changes: 10 additions & 2 deletions paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,25 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");

platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());

for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(client_.Wait());

for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

Expand Down