Skip to content

Commit 5d0ce17

Browse files
authored
add time wait for message bus (#37809)
1 parent 075a02d commit 5d0ce17

File tree

3 files changed

+88
-25
lines changed

3 files changed

+88
-25
lines changed

paddle/fluid/distributed/fleet_executor/compute_interceptor.cc

+49-22
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() {
4646
"Source ComputeInterceptor must run at least one "
4747
"times, but now max_run_times=%ld",
4848
node_->max_run_times()));
49+
in_readys_.emplace(-1,
50+
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
4951
}
5052

5153
// If there is no downstream or every downstream is in different rank,
@@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() {
5557
}
5658

5759
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
58-
// source node has no upstream; data_is_ready is sent by the carrier or others
59-
if (is_source_ && up_id == -1) return;
60-
6160
auto it = in_readys_.find(up_id);
6261
PADDLE_ENFORCE_NE(it, in_readys_.end(),
6362
platform::errors::NotFound(
6463
"Cannot find upstream=%lld in in_readys.", up_id));
6564

65+
// source node has no upstream; data_is_ready is sent by the carrier or others
66+
if (is_source_ && up_id == -1) {
67+
it->second.second = GetTaskNode()->max_run_times();
68+
return;
69+
}
70+
6671
auto max_ready_size = it->second.first;
6772
auto ready_size = it->second.second;
6873
ready_size += 1;
@@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() {
9398
for (auto& ins : in_readys_) {
9499
auto ready_size = ins.second.second;
95100
// not ready, return false
96-
if (ready_size == 0) return false;
101+
if (ready_size == 0) {
102+
VLOG(3) << "Interceptor " << GetInterceptorId()
103+
<< "'s upstreams aren't all ready.";
104+
return false;
105+
}
97106
}
98107
return true;
99108
}
@@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() {
103112
auto max_buffer_size = outs.second.first;
104113
auto used_size = outs.second.second;
105114
// full, return false
106-
if (used_size == max_buffer_size) return false;
115+
if (used_size == max_buffer_size) {
116+
VLOG(3) << "Interceptor " << GetInterceptorId()
117+
<< "'s out buffer is full.";
118+
return false;
119+
}
107120
}
108121
return true;
109122
}
110123

111124
// only the source node needs to be reset
112125
bool ComputeInterceptor::ShouldReset() {
113-
return is_source_ && (step_ == node_->max_run_times());
126+
if (is_source_ && step_ == node_->max_run_times()) {
127+
VLOG(3) << "Interceptor " << GetInterceptorId()
128+
<< " should reset for step: " << step_ << ".";
129+
return true;
130+
}
131+
return false;
114132
}
115133

116134
void ComputeInterceptor::SendDataReadyToDownStream() {
@@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
130148
InterceptorMessage ready_msg;
131149
ready_msg.set_message_type(DATA_IS_READY);
132150
VLOG(3) << "ComputeInterceptor " << interceptor_id_
133-
<< " Send data_is_ready msg to " << down_id;
151+
<< " Send data_is_ready msg to " << down_id
152+
<< " for step: " << step_;
134153
Send(down_id, ready_msg);
135154
}
136155
}
@@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
147166
ready_size));
148167
ins.second.second = ready_size;
149168

169+
VLOG(3) << "ComputeInterceptor " << interceptor_id_
170+
<< " Reply data_is_useless msg to " << up_id
171+
<< " for step: " << step_;
172+
if (up_id == -1) return;
173+
150174
InterceptorMessage reply_msg;
151175
reply_msg.set_message_type(DATE_IS_USELESS);
152-
VLOG(3) << "ComputeInterceptor " << interceptor_id_
153-
<< " Reply data_is_useless msg to " << up_id;
154176
Send(up_id, reply_msg);
155177
}
156178
}
157179

158180
void ComputeInterceptor::RunOps() {
159181
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
160-
<< step_ << " time.";
182+
<< step_ + 1 << " time.";
161183
for (auto op : node_->ops()) {
162184
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
163185
}
164186
}
165187

166188
void ComputeInterceptor::Run() {
189+
// If there is no limit, source interceptor can be executed
190+
// an unlimited number of times.
191+
// Now source node can only run max_run_times.
192+
if (ShouldReset()) {
193+
for (auto& out_buff : out_buffs_) {
194+
// buffer is still in use
195+
if (out_buff.second.second != 0) {
196+
VLOG(3) << "Interceptor " << GetInterceptorId()
197+
<< " out buffer for downstream: " << out_buff.first
198+
<< "'s counter is: " << out_buff.second.second
199+
<< ". Cannot be reset.";
200+
return;
201+
}
202+
}
203+
step_ = 0; // reset
204+
}
205+
167206
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
168207
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
169208

@@ -181,18 +220,6 @@ void ComputeInterceptor::Run() {
181220
StopCarrier();
182221
}
183222
}
184-
185-
// If there is no limit, source interceptor can be executed
186-
// an unlimited number of times.
187-
// Now source node can only run max_run_times.
188-
if (ShouldReset()) {
189-
for (auto& out_buff : out_buffs_) {
190-
// buffer is using
191-
if (out_buff.second.second != 0) return;
192-
}
193-
step_ = 0; // reset
194-
return;
195-
}
196223
}
197224

198225
void ComputeInterceptor::ReceivedStop(int64_t up_id) {

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

+9
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,15 @@ void FleetExecutor::Run() {
109109
message_bus_instance.IsInit(), true,
110110
platform::errors::Unavailable("MessageBus has not been init yet."));
111111
carrier_instance.Start();
112+
for (auto* micro_scop : microbatch_scopes_) {
113+
// By default, we should delete all kid scopes after run executor because
114+
// some operators may create local scope when running, such as while_op.
115+
// But when while_op also creates a local executor to run its sub block,
116+
// the sub scopes it created should not be dropped immediately, because
117+
// while_grad_op will use some variables created during while_op run, so
118+
// we need to keep the kids and wait for the outer executor to drop them.
119+
micro_scop->DropKids();
120+
}
112121
}
113122

114123
void FleetExecutor::CopyParameters(int microbatch_id,

paddle/fluid/distributed/fleet_executor/message_bus.cc

+30-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <chrono>
1616
#include <memory>
17+
#include <set>
1718
#include <thread>
1819

1920
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
@@ -56,11 +57,11 @@ void MessageBus::Init(
5657
bool MessageBus::IsInit() const { return is_init_; }
5758

5859
MessageBus::~MessageBus() {
59-
VLOG(3) << "Message bus releases resource.";
6060
// NOTE: fleet_executor inits carrier before message bus,
6161
// therefore the message bus's destructor will be called first
6262
Carrier& carrier = Carrier::Instance();
6363
carrier.Release();
64+
VLOG(3) << "Message bus releases resource.";
6465
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
6566
!defined(PADDLE_WITH_ASCEND_CL)
6667
server_.Stop(1000);
@@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
9091
<< retry_time << " times retries.";
9192
return true;
9293
}
94+
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
95+
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
9396
}
9497
VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
9598
return false;
@@ -121,16 +124,40 @@ void MessageBus::ListenPort() {
121124
brpc::ServerOptions options;
122125
options.idle_timeout_sec = -1;
123126
int retry_times = 0;
124-
int interval = 1000;
127+
int interval = 100;
125128
while (server_.Start(ip_for_brpc, &options) != 0) {
126129
++retry_times;
127130
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
128131
<< " times. And will retry after " << interval / 1000
129132
<< " seconds.";
130133
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
131-
interval += 2000;
134+
interval += 500;
132135
}
133136
LOG(INFO) << "Message bus's listen port thread starts successful.";
137+
138+
std::set<int64_t> visit;
139+
InterceptorMessage tmp_msg;
140+
tmp_msg.set_ctrl_message(true);
141+
for (auto pair : interceptor_id_to_rank_) {
142+
if (rank_to_addr_.at(pair.second) == addr_) {
143+
tmp_msg.set_src_id(pair.first);
144+
}
145+
}
146+
for (auto pair : interceptor_id_to_rank_) {
147+
int64_t rank = pair.second;
148+
if (rank_to_addr_.at(rank) == addr_) {
149+
continue;
150+
}
151+
tmp_msg.set_dst_id(pair.first);
152+
if (visit.find(rank) == visit.end()) {
153+
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
154+
visit.insert(rank);
155+
while (!Send(tmp_msg)) {
156+
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
157+
}
158+
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
159+
}
160+
}
134161
#else
135162
LOG(WARNING)
136163
<< "Fleet executor's ListenPort() is a fake function when Paddle is "

0 commit comments

Comments
 (0)