Skip to content

Commit a67f2d3

Browse files
jakpiaselidanqing-vv
authored and committed
OneDNN hardswish integration (#30211)
1 parent d44d173 commit a67f2d3

14 files changed

+653
-14
lines changed

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,11 @@ REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
135135
paddle::framework::compatible::OpVersionComparatorCombination()
136136
.LE("conv2d", 1)
137137
.EQ("swish", 0));
138+
139+
// Register the Conv2D + hard_swish fusion pass, mirroring the sibling
// conv_swish registration above. The capability combination pins the op
// versions this pass was written against.
REGISTER_PASS(conv_hard_swish_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DHardSwishFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("hard_swish", 0));

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ class Conv2DSwishFusePass : public ConvActivationFusePass {
6060
public:
6161
std::string activation_type() const { return "swish"; }
6262
};
63+
/*
64+
* Fuse Conv and HardSwish class
65+
*/
66+
class Conv2DHardSwishFusePass : public ConvActivationFusePass {
67+
public:
68+
std::string activation_type() const { return "hard_swish"; }
69+
};
6370
} // namespace ir
6471
} // namespace framework
6572
} // namespace paddle

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
136136
}
137137
TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
138138
TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
139+
// One-line form for consistency with the sibling relu6/swish test cases.
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
  MainTest("hard_swish");
}
139142

140143
} // namespace ir
141144
} // namespace framework
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
16+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
17+
#include "paddle/fluid/framework/op_version_registry.h"
18+
#include "paddle/fluid/platform/enforce.h"
19+
#include "paddle/fluid/string/pretty_log.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
namespace ir {
24+
25+
using string::PrettyLogDetail;
26+
27+
// Runs the FC+activation fusion once per supported activation type.
// Each FuseFCAct call scans the whole graph for FC -> <act> pairs.
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
  const std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid",
                                              "hard_swish"};

  // Iterate by const reference: the original copied each std::string
  // per iteration (clang-tidy performance-for-range-copy).
  for (const auto &act_type : act_types) FuseFCAct(graph, act_type);
}
33+
34+
// Fuses every FC -> act_type pair found in the graph into a single FC op
// that carries the activation as an oneDNN post-op, recorded in the FC op's
// "activation_type" attribute. The standalone activation op and the
// now-dead intermediate FC output var are removed from the graph.
//
// @param graph     graph to rewrite; must not be null.
// @param act_type  activation op type to match (e.g. "gelu", "tanh").
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                    const std::string &act_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init("fc_act", graph);

  GraphPatternDetector gpd;
  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
  fc_act_pattern(act_type);

  int found_fc_act_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "Fuse fc with activation op.";
    // FC output
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
    // ACT output
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
    // ops
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);

    auto *fc_op = fc->Op();
    auto *act_op = act->Op();

    // The fusion only makes sense if the FC op itself runs on oneDNN.
    if (fc_op->HasAttr("use_mkldnn")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
          platform::errors::PreconditionNotMet(
              "The FC+Act fusion may happen only when oneDNN library "
              "is used."));
    }

    // gelu comes in two flavors selected by its "approximate" attribute;
    // encode the flavor into the fused activation name.
    if (act_type == "gelu" && act_op->HasAttr("approximate")) {
      bool approximate = BOOST_GET_CONST(bool, act_op->GetAttr("approximate"));
      std::string type = approximate ? "_tanh" : "_erf";
      fc_op->SetAttr("activation_type", act_type + type);
    } else {  // braces added: braceless else after a braced if is bug-prone
      fc_op->SetAttr("activation_type", act_type);
    }

    fc_op->SetAttr("use_mkldnn", true);

    // Rewire FC to produce the activation's output directly, then drop the
    // activation op and the intermediate FC output var.
    fc_op->SetOutput("Out", {act_out->Name()});

    IR_OP_VAR_LINK(fc, act_out);
    GraphSafeRemoveNodes(g, {act, fc_out});
    found_fc_act_count++;
  };

  gpd(graph, handler);
  AddStatis(found_fc_act_count);
  PrettyLogDetail("--- fused %d fc with %s activation", found_fc_act_count,
                  act_type);
}
88+
89+
} // namespace ir
90+
} // namespace framework
91+
} // namespace paddle
92+
93+
// Register the FC+activation fusion pass and pin the op versions of FC and
// of every activation this pass can fuse.
REGISTER_PASS(fc_act_mkldnn_fuse_pass,
              paddle::framework::ir::FuseFCActOneDNNPass);
REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("fc", 0)
            .LE("gelu", 0)
            .LE("sigmoid", 0)
            .LE("hard_swish", 0)
            .LE("tanh", 0));
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
19+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
20+
#include "paddle/fluid/framework/ir/graph.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
/*
27+
* \brief Fuse the FC and activation operators into single OneDNN's
28+
* FC with post-op.
29+
*
30+
* \note Currently only GeLU, hardswish, sigmoid and tanh are supported as an
31+
* activation function.
32+
*/
33+
class FuseFCActOneDNNPass : public FusePassBase {
34+
public:
35+
virtual ~FuseFCActOneDNNPass() {}
36+
37+
protected:
38+
void ApplyImpl(ir::Graph *graph) const override;
39+
40+
void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
41+
};
42+
43+
} // namespace ir
44+
} // namespace framework
45+
}  // namespace paddle

0 commit comments

Comments
 (0)