Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class RequestSend final : public RequestBase {
virtual std::string GetReqName() { return request_->Varname(); }

virtual void Process() {
queue_->Push(std::make_pair(request_->Varname(), request_));
std::string var_name = GetReqName();
VLOG(3) << "RequestSend " << var_name;
queue_->Push(std::make_pair(var_name, request_));

sendrecv::VoidMessage reply;
responder_.Finish(reply, ::grpc::Status::OK, this);
Expand All @@ -106,7 +108,7 @@ class RequestGet final : public RequestBase {
responder_(&ctx_),
scope_(scope),
queue_(queue) {
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
Expand All @@ -118,6 +120,7 @@ class RequestGet final : public RequestBase {
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
VLOG(3) << "RequestGet " << var_name;
auto* var = scope_->FindVar(var_name);

::grpc::ByteBuffer reply;
Expand Down Expand Up @@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;

std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
Expand Down Expand Up @@ -307,18 +310,20 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
bool ok = false;

while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
VLOG(3) << "HandleRequest for " << cq_name << " get Next";

PADDLE_ENFORCE(tag);

if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
}

RequestBase* base = reinterpret_cast<RequestBase*>(tag);
Expand All @@ -336,9 +341,9 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,

switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne();
base->Process();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
break;
}
case FINISH: {
Expand Down
72 changes: 46 additions & 26 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
}
}

// Runs one prepared sub-program block on the given scope and blocks until it
// completes. The work is dispatched through framework::Async (presumably a
// thread-pool/async helper — confirm against framework/threadpool), but
// future.wait() makes the call synchronous from the caller's point of view.
// Exceptions thrown while running the block are logged and swallowed, not
// propagated to the caller.
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try {
// create_local_scope=false, create_vars=false: run in `scope` as-is.
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// TODO(qiao) maybe we can remove this
future.wait();
}

static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
Expand Down Expand Up @@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} // while(true)
}

// Per-gradient-variable update loop: repeatedly pops received messages for
// `var_name` from its dedicated queue and runs the corresponding prepared
// optimize block on the message's local scope. Runs until `exit_flag`
// (observed by reference) becomes true. NOTE(review): the loop only re-checks
// exit_flag after Pop() returns, so shutdown presumably relies on a final
// message being pushed — confirm against the caller.
static void AsyncUpdateThread(
const std::string &var_name, const bool &exit_flag,
const std::shared_ptr<detail::ReceivedQueue> &queue,
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag) {
// Blocking pop; v.first is the received var name, v.second the payload.
const detail::ReceivedMessage v = queue->Pop();
auto recv_var_name = v.first;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
// Run the prepared block on the message's local scope; exceptions are
// logged and swallowed so one bad update does not kill the thread.
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(),
false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// Wait for completion, so updates for this variable stay sequential.
fs.wait();
}
}

void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;

auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
Expand All @@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
Expand All @@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}

VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false;

VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a thread pool here to bound the number of threads running blocks: the number of CPU cores may be far smaller than the number of grad vars, so spawning one thread per grad var may hurt performance.

std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}

VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
Expand All @@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
v.second->GetMutableLocalScope());
grad_to_queue[recv_var_name]->Push(v);
}

if (exit_flag) {
Expand Down Expand Up @@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
RunAsyncLoop(&executor, program);
}
}

Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/operators/listen_and_serv_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
framework::BlockDesc* prefetch_block) const;

void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
framework::ProgramDesc* program) const;

void Stop() override;

Expand Down
4 changes: 3 additions & 1 deletion python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def complete_op(self):
'endpoint': self.endpoint,
'Fanin': self.fan_in,
'OptimizeBlock': current_block,
'PrefetchBlock': empty_block
'PrefetchBlock': empty_block,
'sync_mode': True, # did not support async now in layers
'grad_to_block_id': [""]
})


Expand Down