Skip to content

Commit 1eceeee

Browse files
wzzjupiotrekobi
authored andcommitted
Fix several bugs for enabling Paddle to train with CINN. (PaddlePaddle#36739)
* Update the content of `test_parallel_executor_run_cinn.py`. * Fix some bugs in the topological sort and `CreateNewSubGraph`. * Update the CINN commit id used by Paddle. * Update the unit test to `add+relu`. * Update according to reviewers' suggestion.
1 parent ec110f3 commit 1eceeee

File tree

9 files changed

+269
-174
lines changed

9 files changed

+269
-174
lines changed

cmake/external/cinn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ add_definitions(-w)
2727
include(ExternalProject)
2828
set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN)
2929
# TODO(zhhsplendid): Modify git tag after we have release tag
30-
set(CINN_GIT_TAG e422c01b7875301996a2baf67a14ba61b0e6192a)
30+
set(CINN_GIT_TAG cb030430d76f42f7310d09608f9b22959ecbcb51)
3131
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON)
3232
set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j)
3333
ExternalProject_Add(

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
5252
ResolveOptionConfliction();
5353

5454
AppendPrintGraphPass("graph_viz_pass", "_original_graph");
55+
56+
#ifdef PADDLE_WITH_CINN
57+
if (FLAGS_use_cinn) {
58+
// Note: This pass is used to enable cinn.
59+
AppendPass("build_cinn_pass");
60+
AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph");
61+
}
62+
#endif
63+
5564
AppendPassWithCheck(strategy_.enable_sequential_execution_,
5665
"sequential_execution_pass");
5766
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
@@ -74,13 +83,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
7483
// Note: This pass is used to check whether the multi_device_graph is right.
7584
AppendPass("multi_devices_check_pass");
7685

77-
#ifdef PADDLE_WITH_CINN
78-
if (FLAGS_use_cinn) {
79-
// Note: This pass is used to enable cinn.
80-
AppendPass("build_cinn_pass");
81-
}
82-
#endif
83-
8486
SetCollectiveContext();
8587
}
8688

paddle/fluid/framework/paddle2cinn/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
2-
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler)
2+
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
33
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
4-
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
4+
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
55
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
66

77
cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)

0 commit comments

Comments
 (0)