Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})

Expand All @@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

add_subdirectory(test)
endif()
1 change: 0 additions & 1 deletion paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
// Set current running carrier
if (*GlobalVal<std::string>::Get() != carrier_id) {
GlobalVal<std::string>::Set(new std::string(carrier_id));
// TODO(liyurui): Move barrier to service
GlobalVal<MessageBus>::Get()->Barrier();
}
carrier->Start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ message InterceptorMessage {

message InterceptorResponse { optional bool rst = 1 [ default = false ]; }

service TheInterceptorMessageService {
rpc InterceptorMessageService(InterceptorMessage)
service MessageService {
rpc ReceiveInterceptorMessage(InterceptorMessage)
returns (InterceptorResponse);
rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
32 changes: 14 additions & 18 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,9 @@ void MessageBus::Barrier() {

bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
IncreaseBarrierCount();
return true;
} else {
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}

void MessageBus::ListenPort() {
Expand All @@ -185,10 +176,9 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message
PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_,
brpc::SERVER_DOESNT_OWN_SERVICE),
0, platform::errors::Unavailable(
"Message bus: init brpc service error."));
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
platform::errors::Unavailable("Message bus: init brpc service error."));

// start the server
const char* ip_for_brpc = addr_.c_str();
Expand Down Expand Up @@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel);
MessageService_Stub stub(&channel);
InterceptorResponse response;
brpc::Controller ctrl;
ctrl.set_log_id(0);
stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
NULL);
}
if (!ctrl.Failed()) {
if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success.";
Expand All @@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
return false;
}
}

#endif

} // namespace distributed
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/message_bus.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif

#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
Expand Down Expand Up @@ -83,7 +83,7 @@ class MessageBus final {

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
InterceptorMessageServiceImpl interceptor_message_service_;
MessageServiceImpl message_service_;
// brpc server
brpc::Server server_;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,37 @@
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

namespace paddle {
namespace distributed {

void InterceptorMessageServiceImpl::InterceptorMessageService(
void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
VLOG(3) << "Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag);
}

void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
<< request->src_id() << " to rank " << request->dst_id();
GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
response->set_rst(true);
}

} // namespace distributed
} // namespace paddle
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
namespace paddle {
namespace distributed {

class InterceptorMessageServiceImpl : public TheInterceptorMessageService {
class MessageServiceImpl : public MessageService {
public:
InterceptorMessageServiceImpl() {}
virtual ~InterceptorMessageServiceImpl() {}
virtual void InterceptorMessageService(
MessageServiceImpl() {}
virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
Expand Down