
[fleet_executor] Add task loop thread pool #38420

Merged
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -10,10 +10,12 @@ else()
set(BRPC_DEPS "")
endif()

cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper gflags glog ${BRPC_DEPS})
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
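The new task_loop_thread_pool target groups task_loop.cc, task_loop_thread.cc and task_loop_thread_pool.cc into a separate library, and fleet_executor now lists it in DEPS. The carrier changes below exercise only three calls on the pool: SetThreadNum(), Start() and GetLoop(). The sketch that follows is an illustration written against those three calls, not the contents of the real task_loop_thread_pool.h; the task queue, QueueTask(), Run() and Stop() are assumptions made for the example.

// Illustrative sketch only; not the real task_loop_thread_pool.h.
// It mirrors the calls made by this PR: SetThreadNum(), Start(), GetLoop().
#include <cassert>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class TaskLoop {
 public:
  using Task = std::function<void()>;

  // Post a task; it runs later on the loop's single worker thread.
  void QueueTask(Task task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

  // Executed by exactly one worker thread; tasks run strictly in FIFO order.
  void Run() {
    while (true) {
      Task task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
        if (stop_ && tasks_.empty()) return;
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();
    }
  }

  void Stop() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<Task> tasks_;
  bool stop_ = false;
};

class TaskLoopThreadPool {
 public:
  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }

  // Spawn one thread per loop; each thread services exactly one TaskLoop.
  void Start() {
    for (int i = 0; i < thread_num_; ++i) {
      loops_.emplace_back(new TaskLoop);
      threads_.emplace_back([loop = loops_.back().get()] { loop->Run(); });
    }
  }

  TaskLoop* GetLoop(int index) {
    assert(index >= 0 && index < static_cast<int>(loops_.size()));
    return loops_[index].get();
  }

  ~TaskLoopThreadPool() {
    for (auto& loop : loops_) loop->Stop();
    for (auto& thread : threads_) thread.join();
  }

 private:
  int thread_num_ = 1;
  std::vector<std::unique_ptr<TaskLoop>> loops_;
  std::vector<std::thread> threads_;
};

The property the carrier relies on is that each TaskLoop drains its queue on one dedicated thread, so work queued to a given loop is serialized without extra locking.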
102 changes: 24 additions & 78 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_ = place;
root_scope_ = root_scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);

// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();

CreateInterceptors();
is_init_ = true;
}

void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.

for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
InterceptorMessage stop_msg;
// source node STOP is send by carrier, so set src_id=-1
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
Send(stop_msg);
}

// TODO(wangxi): Maybe need a better to use thread.
for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join();
}
}
void Carrier::Release() {}

Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

@@ -75,25 +62,18 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
msg_bus_->IncreaseBarrierCount();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
if (creating_interceptors_) {
std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
}
return true;
}

void Carrier::Barrier() { msg_bus_->Barrier(); }

Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_.wait(lock);
}

void Carrier::WakeUp() {
// probably double notify, but ok for ut
cond_var_.notify_all();
}

void Carrier::Start() {
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
// TODO(wangxi): async step
Wait();
dev_ctx_->Wait();
}

std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

int64_t Carrier::GetRank(int64_t interceptor_id) const {
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique.",
interceptor_id));
interceptor->RegisterCarrier(this);

// TODO(fleet_exe dev): get loop
auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
PADDLE_ENFORCE_NOT_NULL(
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);

auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
return ptr;
}

void Carrier::SetCreatingFlag(bool flag) {
// set the creating flag
creating_flag_mutex_.lock();
VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
<< " to " << flag << ".";
creating_interceptors_ = flag;
creating_flag_mutex_.unlock();
if (!flag) {
for (auto& pair : interceptor_idx_to_interceptor_) {
// update the source interceptor id
if (std::find(source_interceptor_ids_.begin(),
source_interceptor_ids_.end(),
pair.first) == source_interceptor_ids_.end()) {
auto task = pair.second->GetTaskNode();
if (task != nullptr && task->upstream().empty()) {
source_interceptor_ids_.emplace_back(pair.first);
}
}
}
// finish create interceptors outside, handle tmp messsages
HandleTmpMessages();
}
}

void Carrier::HandleTmpMessages() {
// NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, therefore, there won't have conflict with the
// lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std::unique_lock<std::mutex> lock(tmp_message_mutex_);
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}

static std::shared_ptr<framework::GarbageCollector> GetGC(
const platform::Place& place) {
int64_t max_memory_size = framework::GetEagerDeletionThreshold();
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_.emplace_back(interceptor_id);
}
}
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
creating_flag_mutex_.lock();
creating_interceptors_ = false;
creating_flag_mutex_.unlock();
HandleTmpMessages();
}

} // namespace distributed
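Carrier::Start() now parks on cond_var_ in Wait(), and the new WakeUp() releases it with notify_all(), replacing the old STOP-message and per-interceptor Join() logic in Release(). Below is a standalone version of that handshake. A boolean flag is added so the example cannot miss a notification that fires before the waiter blocks; the PR's Carrier::Wait() waits without a predicate. All names in the snippet are invented for the illustration.

// Standalone sketch of the Wait()/WakeUp() handshake used above.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Waiter {
 public:
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_var_.wait(lock, [this] { return woken_; });
  }

  void WakeUp() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      woken_ = true;
    }
    // Possibly a double notify, as the carrier.cc comment notes; extra
    // notifications are harmless because the waiter checks the flag.
    cond_var_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_var_;
  bool woken_ = false;
};

int main() {
  Waiter carrier_like;
  // Plays the role of the last interceptor calling StopCarrier().
  std::thread sink([&] { carrier_like.WakeUp(); });
  carrier_like.Wait();  // Carrier::Start() blocks here until woken
  sink.join();
  std::cout << "carrier woke up\n";
}

In the real code the wake-up presumably arrives through StopCarrier(), which the last compute interceptor invokes in compute_interceptor.cc further down.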
26 changes: 12 additions & 14 deletions paddle/fluid/distributed/fleet_executor/carrier.h
@@ -24,6 +24,7 @@

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
@@ -47,7 +48,11 @@ class Carrier final {
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
@@ -56,6 +61,7 @@ class Carrier final {

void Release();
void Wait();
void WakeUp();

// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
@@ -67,32 +73,25 @@
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);

void SetCreatingFlag(bool flag);
void SetCreatingFlag(bool flag) {}
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}

std::condition_variable& GetCondVar();

void Start();

bool IsInit() const;

bool Send(const InterceptorMessage& msg);

// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std::mutex run;
void Barrier();

private:
DISABLE_COPY_AND_ASSIGN(Carrier);

// create each Interceptor
void CreateInterceptors();

void HandleTmpMessages();

int64_t GetRank(int64_t interceptor_id) const;

// interceptor logic id to actually interceptor
@@ -101,10 +100,6 @@ class Carrier final {

std::vector<int64_t> source_interceptor_ids_;

std::vector<InterceptorMessage> message_tmp_{};
std::mutex tmp_message_mutex_;
bool creating_interceptors_{true};
std::mutex creating_flag_mutex_;
bool is_init_{false};

std::mutex running_mutex_;
@@ -118,6 +113,9 @@ class Carrier final {
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;

int thread_num_;
TaskLoopThreadPool thread_pool_;
};

} // namespace distributed
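With the pool owned by the Carrier constructor, SetInterceptor() binds every interceptor to the loop at index interceptor_id % thread_num_, and the carrier-wide run mutex disappears from the header. The snippet below reuses the illustrative TaskLoop / TaskLoopThreadPool sketched after the CMakeLists.txt hunk to show that mapping; the interceptor ids and the thread count of 2 are made up for the example (the PR itself still hard-codes thread_num_ = 1).

// Assumes the TaskLoop / TaskLoopThreadPool sketch above is in scope.
#include <cstdint>
#include <cstdio>

int main() {
  TaskLoopThreadPool pool;
  pool.SetThreadNum(2);
  pool.Start();

  const int thread_num = 2;
  const std::int64_t interceptor_ids[] = {0, 1, 2, 3};
  for (std::int64_t id : interceptor_ids) {
    // Mirrors Carrier::SetInterceptor: loop index is interceptor_id % thread_num_.
    TaskLoop* loop = pool.GetLoop(static_cast<int>(id % thread_num));
    loop->QueueTask([id] {
      // Everything queued for one interceptor lands on the same loop, so its
      // ops run one after another without taking a shared carrier mutex.
      std::printf("interceptor %lld ran on its registered loop\n",
                  static_cast<long long>(id));
    });
  }
  // The pool destructor drains the queues and joins the worker threads.
  return 0;
}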
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}

void ComputeInterceptor::RunOps() {
std::unique_lock<std::mutex> lock(carrier_->run);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier();
}
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier();
InitCarrier();
InitMessageBus();

// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier()->Barrier();
}

void FleetExecutor::InitCarrier() {
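The Barrier() call added at the end of FleetExecutor::Init() makes each rank's carrier wait until its peers have reached the same point, and the note about Carrier::SetMsgBus matters because incoming barrier control messages are counted through MessageBus::IncreaseBarrierCount(), as seen in EnqueueInterceptorMessage() above. The snippet below reduces that idea to a counting barrier between threads in one process; it only illustrates the counting logic and is not how MessageBus implements Barrier() over brpc. All names are invented for the example.

#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

class CountingBarrier {
 public:
  explicit CountingBarrier(int world_size) : world_size_(world_size) {}

  // Called when a barrier control message arrives from a peer
  // (the role played by MessageBus::IncreaseBarrierCount in the PR).
  void IncreaseBarrierCount() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++count_;
    if (count_ >= world_size_) cv_.notify_all();
  }

  // Called by the local carrier; returns once every rank has checked in.
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ >= world_size_; });
  }

 private:
  const int world_size_;
  int count_ = 0;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  constexpr int kWorldSize = 4;
  CountingBarrier barrier(kWorldSize);
  std::vector<std::thread> ranks;
  for (int r = 0; r < kWorldSize; ++r) {
    ranks.emplace_back([&barrier] {
      barrier.IncreaseBarrierCount();  // "send" the barrier message
      barrier.Wait();                  // the GetCarrier()->Barrier() step
    });
  }
  for (auto& t : ranks) t.join();
}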