@@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() {
46
46
" Source ComputeInterceptor must run at least one "
47
47
" times, but now max_run_times=%ld" ,
48
48
node_->max_run_times ()));
49
+ in_readys_.emplace (-1 ,
50
+ std::make_pair (std::numeric_limits<int64_t >::max (), 0 ));
49
51
}
50
52
51
53
// If there is no downstream or every downstream is in different rank,
@@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() {
55
57
}
56
58
57
59
void ComputeInterceptor::IncreaseReady (int64_t up_id) {
58
- // source node has no upstream, data_is_ready is send by carrier or others
59
- if (is_source_ && up_id == -1 ) return ;
60
-
61
60
auto it = in_readys_.find (up_id);
62
61
PADDLE_ENFORCE_NE (it, in_readys_.end (),
63
62
platform::errors::NotFound (
64
63
" Cannot find upstream=%lld in in_readys." , up_id));
65
64
65
+ // source node has no upstream, data_is_ready is send by carrier or others
66
+ if (is_source_ && up_id == -1 ) {
67
+ it->second .second = GetTaskNode ()->max_run_times ();
68
+ return ;
69
+ }
70
+
66
71
auto max_ready_size = it->second .first ;
67
72
auto ready_size = it->second .second ;
68
73
ready_size += 1 ;
@@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() {
93
98
for (auto & ins : in_readys_) {
94
99
auto ready_size = ins.second .second ;
95
100
// not ready, return false
96
- if (ready_size == 0 ) return false ;
101
+ if (ready_size == 0 ) {
102
+ VLOG (3 ) << " Interceptor " << GetInterceptorId ()
103
+ << " 's upstreams aren't all ready." ;
104
+ return false ;
105
+ }
97
106
}
98
107
return true ;
99
108
}
@@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() {
103
112
auto max_buffer_size = outs.second .first ;
104
113
auto used_size = outs.second .second ;
105
114
// full, return false
106
- if (used_size == max_buffer_size) return false ;
115
+ if (used_size == max_buffer_size) {
116
+ VLOG (3 ) << " Interceptor " << GetInterceptorId ()
117
+ << " 's out buffer is full." ;
118
+ return false ;
119
+ }
107
120
}
108
121
return true ;
109
122
}
110
123
111
124
// only source node need reset
112
125
bool ComputeInterceptor::ShouldReset () {
113
- return is_source_ && (step_ == node_->max_run_times ());
126
+ if (is_source_ && step_ == node_->max_run_times ()) {
127
+ VLOG (3 ) << " Interceptor " << GetInterceptorId ()
128
+ << " should reset for step: " << step_ << " ." ;
129
+ return true ;
130
+ }
131
+ return false ;
114
132
}
115
133
116
134
void ComputeInterceptor::SendDataReadyToDownStream () {
@@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
130
148
InterceptorMessage ready_msg;
131
149
ready_msg.set_message_type (DATA_IS_READY);
132
150
VLOG (3 ) << " ComputeInterceptor " << interceptor_id_
133
- << " Send data_is_ready msg to " << down_id;
151
+ << " Send data_is_ready msg to " << down_id
152
+ << " for step: " << step_;
134
153
Send (down_id, ready_msg);
135
154
}
136
155
}
@@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
147
166
ready_size));
148
167
ins.second .second = ready_size;
149
168
169
+ VLOG (3 ) << " ComputeInterceptor " << interceptor_id_
170
+ << " Reply data_is_useless msg to " << up_id
171
+ << " for step: " << step_;
172
+ if (up_id == -1 ) return ;
173
+
150
174
InterceptorMessage reply_msg;
151
175
reply_msg.set_message_type (DATE_IS_USELESS);
152
- VLOG (3 ) << " ComputeInterceptor " << interceptor_id_
153
- << " Reply data_is_useless msg to " << up_id;
154
176
Send (up_id, reply_msg);
155
177
}
156
178
}
157
179
158
180
void ComputeInterceptor::RunOps () {
159
181
VLOG (3 ) << " ComputeInterceptor " << interceptor_id_ << " running ops for the "
160
- << step_ << " time." ;
182
+ << step_ + 1 << " time." ;
161
183
for (auto op : node_->ops ()) {
162
184
op->Run (*microbatch_scopes_[step_ % node_->max_run_times ()], place_);
163
185
}
164
186
}
165
187
166
188
void ComputeInterceptor::Run () {
189
+ // If there is no limit, source interceptor can be executed
190
+ // an unlimited number of times.
191
+ // Now source node can only run max_run_times.
192
+ if (ShouldReset ()) {
193
+ for (auto & out_buff : out_buffs_) {
194
+ // buffer is using
195
+ if (out_buff.second .second != 0 ) {
196
+ VLOG (3 ) << " Interceptor " << GetInterceptorId ()
197
+ << " out buffer for downstream: " << out_buff.first
198
+ << " 's counter is: " << out_buff.second .second
199
+ << " . Cannot be reset." ;
200
+ return ;
201
+ }
202
+ }
203
+ step_ = 0 ; // reset
204
+ }
205
+
167
206
while (IsInputReady () && CanWriteOutput () && !ShouldReset ()) {
168
207
VLOG (3 ) << " id=" << GetInterceptorId () << " ComputeInterceptor running" ;
169
208
@@ -181,18 +220,6 @@ void ComputeInterceptor::Run() {
181
220
StopCarrier ();
182
221
}
183
222
}
184
-
185
- // If there is no limit, source interceptor can be executed
186
- // an unlimited number of times.
187
- // Now source node can only run max_run_times.
188
- if (ShouldReset ()) {
189
- for (auto & out_buff : out_buffs_) {
190
- // buffer is using
191
- if (out_buff.second .second != 0 ) return ;
192
- }
193
- step_ = 0 ; // reset
194
- return ;
195
- }
196
223
}
197
224
198
225
void ComputeInterceptor::ReceivedStop (int64_t up_id) {
0 commit comments