Skip to content

[fleet_executor] Framework for message and manager part. #36966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ if (WITH_PSCORE)

include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)

if (WITH_HETERPS)
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
Expand Down
22 changes: 20 additions & 2 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
proto_library(fleet_executor_desc_proto SRCS fleet_executor_desc.proto)
cc_library(fleet_executor SRCS fleet_executor.cc DEPS fleet_executor_desc_proto)

if(WITH_PYTHON)
py_proto_compile(fleet_executor_desc_py_proto SRCS fleet_executor_desc.proto)
endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)

if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc)
else()
set(BRPC_DEPS "")
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc
interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
43 changes: 43 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

Carrier::Carrier(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
// init
}

Carrier::~Carrier() {
// destroy
}

bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
return true;
}

void Carrier::CreateInterceptors() {
// create each Interceptor
}

} // namespace distributed
} // namespace paddle
60 changes: 60 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace distributed {

class Interceptor;
class TaskNode;
class InterceptorMessageServiceImpl;

class Carrier final {
public:
Carrier() = delete;

Carrier(const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);

~Carrier();

// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);

DISABLE_COPY_AND_ASSIGN(Carrier);

private:
// create each Interceptor
void CreateInterceptors();

// get interceptor based on the interceptor id
Interceptor* GetInterceptor(int64_t interceptor_id);

// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;

// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
};

} // namespace distributed
} // namespace paddle
10 changes: 10 additions & 0 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,15 @@ void FleetExecutor::Release() {
// Release
}

std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
// get carrier
return nullptr;
}

std::shared_ptr<MessageBus> FleetExecutor::GetMessageBus() {
// get message bus
return nullptr;
}

} // namespace distributed
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once
#include <memory>

#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"

Expand All @@ -24,6 +25,8 @@ class ProgramDesc;

namespace distributed {
class RuntimeGraph;
class Carrier;
class MessageBus;

class FleetExecutor final {
public:
Expand All @@ -33,11 +36,15 @@ class FleetExecutor final {
void Init(const paddle::framework::ProgramDesc& program_desc);
void Run();
void Release();
static std::shared_ptr<Carrier> GetCarrier();
static std::shared_ptr<MessageBus> GetMessageBus();

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
static std::shared_ptr<Carrier> global_carrier_;
static std::shared_ptr<MessageBus> global_message_bus_;
};

} // namespace distributed
Expand Down
46 changes: 46 additions & 0 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
namespace distributed {

Interceptor::Interceptor(int64_t interceptor_id_, TaskNode* node) {
// init
}

int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return 0;
}

bool Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
return true;
}

void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
}

bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
return true;
}

} // namespace distributed
} // namespace paddle
83 changes: 83 additions & 0 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <condition_variable>
#include <map>
#include <memory>
#include <queue>
#include <thread>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace distributed {

class TaskNode;

class Interceptor {
public:
Interceptor() = delete;

Interceptor(int64_t interceptor_id_, TaskNode* node);

virtual ~Interceptor() = default;

// return the interceptor id
int64_t GetInterceptorId() const;

// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);

DISABLE_COPY_AND_ASSIGN(Interceptor);

private:
// pool the local mailbox, parse the Message
void PoolTheMailbox();

// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
bool FetchRemoteMailbox();

// interceptor id, handed from above layer
int64_t interceptor_id_;

// node need to be handled by this interceptor
TaskNode* node_;

// mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_;

// interceptor runs PoolTheMailbox() function to poll local mailbox
std::thread interceptor_thread_;

// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std::condition_variable cond_var_;

// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
std::queue<InterceptorMessage> remote_mailbox_;

// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::queue<InterceptorMessage> local_mailbox_;
};

} // namespace distributed
} // namespace paddle
40 changes: 40 additions & 0 deletions paddle/fluid/distributed/fleet_executor/interceptor_message.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;

enum MessageType {
STOP = 1; // STOP an Interceptor
DATA_IS_READY = 2; // upstream data is ready
DATE_IS_USELESS = 3; // downstream has used the data
ERROR = 4; // current Interceptor encounters error
RESET = 5; // reset the status
}

message InterceptorMessage {
optional int64 src_id = 1 [ default = 0 ];
optional int64 dst_id = 2 [ default = 0 ];
optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ];
}

message InterceptorResponse { optional bool rst = 1 [ default = false ]; }

service TheInterceptorMessageService {
rpc InterceptorMessageService(InterceptorMessage)
returns (InterceptorResponse);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"

namespace paddle {
namespace distributed {

void InterceptorMessageServiceImpl::InterceptorMessageService(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
// receive msg
}

} // namespace distributed
} // namespace paddle
#endif
#endif
Loading