Skip to content

Commit 09ec077

Browse files
Use helper functions in two more fuse pass tests.
1 parent 9f1cf86 commit 09ec077

File tree

6 files changed

+211
-496
lines changed

6 files changed

+211
-496
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ if (WITH_MKLDNN)
163163
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
164164
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
165165
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
166-
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
167-
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass)
166+
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
167+
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
168168
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
169169
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
170170
if (WITH_GPU)

paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
1818
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
1919
#include "paddle/fluid/framework/op_desc.h"
20+
#include "paddle/fluid/framework/op_version_registry.h"
2021
#include "paddle/fluid/framework/program_desc.h"
2122
#include "paddle/fluid/platform/errors.h"
2223

@@ -63,9 +64,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
6364
// No fusion in this attribute configuration
6465
constexpr int removed_nodes_count = 0;
6566

66-
EXPECT_THROW(
67-
test::RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
68-
paddle::platform::EnforceNotMet);
67+
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
68+
"act_y", removed_nodes_count),
69+
paddle::platform::EnforceNotMet);
6970
}
7071

7172
TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
@@ -83,8 +84,8 @@ TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
8384
Graph graph(prog);
8485
constexpr int removed_nodes_count = 2;
8586

86-
EXPECT_TRUE(
87-
test::RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count));
87+
EXPECT_TRUE(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
88+
"act_y", removed_nodes_count));
8889
EXPECT_TRUE(test::AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}}));
8990

9091
for (const auto* node : graph.Nodes()) {
@@ -121,9 +122,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
121122
// No fusion in this attribute configuration
122123
constexpr int removed_nodes_count = 0;
123124

124-
EXPECT_THROW(
125-
test::RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
126-
paddle::platform::EnforceNotMet);
125+
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
126+
"act_y", removed_nodes_count),
127+
paddle::platform::EnforceNotMet);
127128
}
128129

129130
TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
@@ -147,9 +148,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
147148
// No fusion in this attribute configuration
148149
constexpr int removed_nodes_count = 0;
149150

150-
EXPECT_THROW(
151-
test::RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
152-
paddle::platform::EnforceNotMet);
151+
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
152+
"act_y", removed_nodes_count),
153+
paddle::platform::EnforceNotMet);
153154
}
154155

155156
TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
@@ -174,9 +175,15 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
174175
// No fusion in this attribute configuration
175176
constexpr int removed_nodes_count = 0;
176177

177-
EXPECT_THROW(
178-
test::RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
179-
paddle::platform::EnforceNotMet);
178+
EXPECT_THROW(test::RunPassAndAssert(&graph, "batch_norm_act_fuse_pass", "x",
179+
"act_y", removed_nodes_count),
180+
paddle::platform::EnforceNotMet);
181+
}
182+
183+
TEST(FuseBatchNormActOneDNNPass, pass_op_version_check) {
184+
ASSERT_TRUE(
185+
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
186+
.IsPassCompatible("batch_norm_act_fuse_pass"));
180187
}
181188

182189
} // namespace ir

0 commit comments

Comments
 (0)