Skip to content

Commit 15d2e3c

Browse files
authored
Merge branch 'develop' into dropout_opt_clean_BcInT
2 parents 58abfa2 + 3206fa8 commit 15d2e3c

File tree

347 files changed

+5245
-2858
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

347 files changed

+5245
-2858
lines changed

cmake/operators.cmake

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ function(op_library TARGET)
6262
set(hip_cc_srcs)
6363
set(xpu_cc_srcs)
6464
set(xpu_kp_cc_srcs)
65-
set(mlu_cc_srcs)
6665
set(cudnn_cu_cc_srcs)
6766
set(miopen_cu_cc_srcs)
6867
set(cudnn_cu_srcs)
@@ -307,9 +306,8 @@ function(op_library TARGET)
307306
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
308307
if(WITH_UNITY_BUILD AND op_library_UNITY)
309308
# Combine the cc source files.
310-
compose_unity_target_sources(
311-
${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs}
312-
${mlu_cc_srcs})
309+
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs}
310+
${mkldnn_cc_srcs} ${xpu_cc_srcs})
313311
if(TARGET ${UNITY_TARGET})
314312
# If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
315313
target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
@@ -325,7 +323,7 @@ function(op_library TARGET)
325323
else()
326324
cc_library(
327325
${TARGET}
328-
SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${mlu_cc_srcs}
326+
SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs}
329327
DEPS ${op_library_DEPS} ${op_common_deps})
330328
endif()
331329
endif()
@@ -337,7 +335,6 @@ function(op_library TARGET)
337335
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
338336
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
339337
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
340-
list(LENGTH mlu_cc_srcs mlu_cc_srcs_len)
341338

342339
# Define operators that don't need pybind here.
343340
foreach(
@@ -562,7 +559,6 @@ function(register_operators)
562559
"*_op.cc")
563560
string(REPLACE "_mkldnn" "" OPS "${OPS}")
564561
string(REPLACE "_xpu" "" OPS "${OPS}")
565-
string(REPLACE "_mlu" "" OPS "${OPS}")
566562
string(REPLACE ".cc" "" OPS "${OPS}")
567563
list(REMOVE_DUPLICATES OPS)
568564
list(LENGTH register_operators_DEPS register_operators_DEPS_len)

cmake/phi_header.cmake

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,24 @@ set(PADDLE_INFERENCE_INSTALL_DIR
1717

1818
function(phi_header_path_compat TARGET_PATH)
1919
message(STATUS "phi header path compat processing: ${TARGET_PATH}")
20-
file(GLOB HEADERS "${TARGET_PATH}/*" "*.h")
21-
foreach(header ${HEADERS})
22-
if(${header} MATCHES ".*.h$")
23-
file(READ ${header} HEADER_CONTENT)
24-
string(REPLACE "paddle/phi/" "paddle/include/experimental/phi/"
25-
HEADER_CONTENT "${HEADER_CONTENT}")
26-
string(REPLACE "paddle/fluid/platform/"
27-
"paddle/include/experimental/phi/" HEADER_CONTENT
28-
"${HEADER_CONTENT}")
29-
string(REPLACE "paddle/utils/" "paddle/include/experimental/utils/"
30-
HEADER_CONTENT "${HEADER_CONTENT}")
31-
file(WRITE ${header} "${HEADER_CONTENT}")
32-
message(STATUS "phi header path compat processing complete: ${header}")
33-
endif()
34-
endforeach()
20+
string(FIND ${TARGET_PATH} "experimental" pos)
21+
if(pos GREATER 1)
22+
file(GLOB HEADERS "${TARGET_PATH}/*" "*.h")
23+
foreach(header ${HEADERS})
24+
if(${header} MATCHES ".*.h$")
25+
file(READ ${header} HEADER_CONTENT)
26+
string(REPLACE "paddle/phi/" "paddle/include/experimental/phi/"
27+
HEADER_CONTENT "${HEADER_CONTENT}")
28+
string(REPLACE "paddle/fluid/platform/"
29+
"paddle/include/experimental/phi/" HEADER_CONTENT
30+
"${HEADER_CONTENT}")
31+
string(REPLACE "paddle/utils/" "paddle/include/experimental/utils/"
32+
HEADER_CONTENT "${HEADER_CONTENT}")
33+
file(WRITE ${header} "${HEADER_CONTENT}")
34+
message(STATUS "phi header path compat processing complete: ${header}")
35+
endif()
36+
endforeach()
37+
endif()
3538
endfunction()
3639

3740
phi_header_path_compat(
@@ -48,7 +51,16 @@ phi_header_path_compat(
4851
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/common)
4952
phi_header_path_compat(
5053
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/phi/core)
51-
phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/)
54+
55+
# NOTE(liuyuanle): The inference lib does not need to include paddle/utils/pybind.h, so we delete it here.
56+
file(READ
57+
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/extension.h
58+
HEADER_CONTENT)
59+
string(REGEX REPLACE "#if !defined\\(PADDLE_ON_INFERENCE\\).*#endif" ""
60+
HEADER_CONTENT "${HEADER_CONTENT}")
61+
file(WRITE
62+
${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/extension.h
63+
"${HEADER_CONTENT}")
5264

5365
# In order to be compatible with the original behavior, the header file name needs to be changed
5466
file(RENAME

paddle/extension.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License. */
1616

1717
// All paddle apis in C++ frontend
1818
#include "paddle/phi/api/all.h"
19-
// Python bindings for the C++ frontend (includes Python.h)
2019
#if !defined(PADDLE_ON_INFERENCE) && !defined(PADDLE_NO_PYTHON)
20+
// Python bindings for the C++ frontend (includes Python.h)
2121
#include "paddle/utils/pybind.h"
2222
#endif
2323
// For initialization of DeviceContextPool and MemoryMethod

paddle/fluid/distributed/fleet_executor/carrier.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <algorithm>
1818
#include <vector>
1919

20+
#include "gflags/gflags.h"
2021
#include "paddle/fluid/distributed/fleet_executor/global.h"
2122
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
2223
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
@@ -237,12 +238,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
237238
VLOG(3) << "Send a message from interceptor " << src_id
238239
<< " to interceptor " << dst_id << ", which are in the same ranks.";
239240
return EnqueueInterceptorMessage(msg);
240-
} else {
241-
VLOG(3) << "Send a message from interceptor " << src_id
242-
<< " to interceptor " << dst_id
243-
<< ", which are in different ranks.";
244-
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
245241
}
242+
VLOG(3) << "Send a message from interceptor " << src_id << " to interceptor "
243+
<< dst_id << ", which are in different ranks.";
244+
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
246245
}
247246

248247
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,

paddle/fluid/distributed/fleet_executor/compute_interceptor.cc

Lines changed: 75 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -158,25 +158,47 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
158158
}
159159

160160
bool ComputeInterceptor::IsInputReady() {
161-
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
161+
std::map<int64_t, bool> scope_id_to_finish_flag;
162+
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
163+
scope_id_to_finish_flag =
164+
gen_step_to_scope_id_to_finish_flag_.begin()->second;
165+
VLOG(3) << "Is Input Ready in gen step "
166+
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
167+
}
168+
int64_t num_micro_step =
169+
(num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_);
170+
int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_);
171+
for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step;
172+
++i) {
162173
bool flag = true;
163174
for (auto& ins : in_readys_) {
164175
auto ready_size_map = ins.second.second;
165176
flag = flag && (ready_size_map.at(i) != 0);
166177
}
167178
if (flag) {
168-
for (auto iter : scope_id_to_finish_flag_) {
169-
if (iter.first == i) {
170-
break;
171-
} else if (!iter.second) {
172-
VLOG(3) << "The previous scope is not ready, waiting for the "
173-
"previous scope "
174-
<< iter.first;
175-
return false;
179+
if (scope_id_to_finish_flag.empty()) {
180+
cur_scope_id_ = i;
181+
return true;
182+
} else if (scope_id_to_finish_flag.find(i) !=
183+
scope_id_to_finish_flag.end()) {
184+
for (auto iter : scope_id_to_finish_flag) {
185+
if (iter.first == i) {
186+
break;
187+
} else if (!iter.second) {
188+
VLOG(3) << "The previous scope is not ready, waiting for the "
189+
"previous scope "
190+
<< iter.first << " in gen_step "
191+
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
192+
return false;
193+
}
176194
}
195+
cur_scope_id_ = i;
196+
return true;
197+
} else {
198+
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
199+
<< " is larger than gen_step "
200+
<< gen_step_to_scope_id_to_finish_flag_.begin()->first;
177201
}
178-
cur_scope_id_ = i;
179-
return true;
180202
} else {
181203
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
182204
<< "'s upstreams aren't all ready.";
@@ -203,6 +225,16 @@ bool ComputeInterceptor::CanWriteOutput() {
203225
}
204226

205227
void ComputeInterceptor::SendDataReadyToDownStream() {
228+
bool need_send_vars = !(node_->vars_to_dtype().empty());
229+
InterceptorMessage ready_msg;
230+
ready_msg.set_start_micro_step(start_micro_step_);
231+
ready_msg.set_num_micro_step(num_micro_step_);
232+
if (need_send_vars) {
233+
ready_msg = PrepareVarsMsg();
234+
} else {
235+
ready_msg.set_message_type(DATA_IS_READY);
236+
ready_msg.set_scope_idx(cur_scope_id_);
237+
}
206238
for (auto& outs : out_buffs_) {
207239
auto down_id = outs.first;
208240
auto max_buff_size = outs.second.first;
@@ -221,13 +253,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
221253
}
222254
outs.second.second = used_size;
223255

224-
InterceptorMessage ready_msg;
225-
ready_msg.set_message_type(DATA_IS_READY);
226-
ready_msg.set_scope_idx(cur_scope_id_);
227-
VLOG(3) << "ComputeInterceptor " << interceptor_id_
228-
<< " Send data_is_ready msg to " << down_id
229-
<< " in scope: " << cur_scope_id_;
230-
Send(down_id, ready_msg);
256+
if (need_send_vars) {
257+
VLOG(3) << "ComputeInterceptor " << interceptor_id_
258+
<< " Send data_with_vars msg to " << down_id
259+
<< " in scope: " << cur_scope_id_;
260+
Send(down_id, ready_msg);
261+
} else {
262+
VLOG(3) << "ComputeInterceptor " << interceptor_id_
263+
<< " Send data_is_ready msg to " << down_id
264+
<< " in scope: " << cur_scope_id_;
265+
Send(down_id, ready_msg);
266+
}
231267
}
232268
}
233269

@@ -289,13 +325,21 @@ void ComputeInterceptor::Run() {
289325

290326
RunOps();
291327

292-
if (!scope_id_to_finish_flag_.empty()) {
328+
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
329+
auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
330+
VLOG(3) << "id=" << GetInterceptorId()
331+
<< " ComputeInterceptor running in scope " << cur_scope_id_
332+
<< " with gen_step " << iter->first;
333+
auto& scope_id_to_finish_flag = iter->second;
293334
PADDLE_ENFORCE_NE(
294-
scope_id_to_finish_flag_.find(cur_scope_id_),
295-
scope_id_to_finish_flag_.end(),
335+
scope_id_to_finish_flag.find(cur_scope_id_),
336+
scope_id_to_finish_flag.end(),
296337
platform::errors::NotFound(
297338
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
298-
scope_id_to_finish_flag_.erase(cur_scope_id_);
339+
scope_id_to_finish_flag.erase(cur_scope_id_);
340+
if (scope_id_to_finish_flag.empty()) {
341+
gen_step_to_scope_id_to_finish_flag_.erase(iter);
342+
}
299343
}
300344

301345
// send to downstream and increase buff used
@@ -310,6 +354,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
310354
VLOG(3) << "Compute interceptor " << interceptor_id_
311355
<< " receive data_is_ready " << msg.src_id() << " "
312356
<< msg.scope_idx() << " ";
357+
start_micro_step_ = msg.start_micro_step();
358+
num_micro_step_ = msg.num_micro_step();
313359
IncreaseReady(msg.src_id(), msg.scope_idx());
314360
Run();
315361
} else if (msg.message_type() == DATA_IS_USELESS) {
@@ -327,10 +373,14 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
327373
Run();
328374
} else if (msg.message_type() == START_LOOP) {
329375
VLOG(3) << "Compute interceptor " << interceptor_id_
330-
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx()
331-
<< " ";
376+
<< " receive start_loop " << msg.src_id() << " in scope "
377+
<< msg.scope_idx() << " with gen_step " << msg.gen_step();
378+
start_micro_step_ = msg.start_micro_step();
379+
num_micro_step_ = msg.num_micro_step();
332380
IncreaseReady(msg.src_id(), msg.scope_idx());
333-
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
381+
int64_t gen_step = msg.gen_step();
382+
gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(),
383+
false);
334384
Run();
335385
}
336386
}

paddle/fluid/distributed/fleet_executor/compute_interceptor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class ComputeInterceptor : public Interceptor {
5252

5353
bool IsInputReady();
5454
bool CanWriteOutput();
55-
std::map<int64_t, bool> scope_id_to_finish_flag_;
55+
std::map<int64_t, std::map<int64_t, bool>>
56+
gen_step_to_scope_id_to_finish_flag_;
57+
int64_t start_micro_step_{-1};
58+
int64_t num_micro_step_{-1};
5659
};
5760

5861
} // namespace distributed

0 commit comments

Comments
 (0)