
BN+ReLU fwd fusion #14

Merged
merged 1 commit into from Apr 29, 2020
56 changes: 45 additions & 11 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -268,6 +268,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.dequantize = "Dequantize";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
@@ -296,6 +297,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
"_MklDepthwiseConv2dNativeBackpropInput";
csinfo_.mkl_depthwise_conv2d_grad_filter =
"_MklDepthwiseConv2dNativeBackpropFilter";
csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
@@ -480,6 +482,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.fused_batch_norm_grad_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifdef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
#endif
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
@@ -931,6 +938,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string dequantize;
string fused_batch_norm;
string fused_batch_norm_grad;
string fused_batch_norm_ex;
string fused_batch_norm_v2;
string fused_batch_norm_grad_v2;
string fused_batch_norm_v3;
@@ -957,6 +965,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_conv2d_with_bias;
string mkl_depthwise_conv2d_grad_input;
string mkl_depthwise_conv2d_grad_filter;
string mkl_fused_batch_norm_ex;
string mkl_fused_conv2d;
string mkl_fused_depthwise_conv2d;
string mkl_fused_matmul;
@@ -1677,6 +1686,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return do_rewrite;
}

static bool FusedBatchNormExRewrite(const Node* n) {
CHECK_NOTNULL(n);

int num_side_inputs;
TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
string activation_mode;
TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));

// If num_side_inputs is not 0, don't rewrite the node.
if (num_side_inputs != 0) {
VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs"
<< "larger than 0 is not optimized by Intel MKL.";
return false;
}

// If activation_mode is not 'Relu', don't rewrite the node.
if (activation_mode != "Relu") {
VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is"
<< "supported by Intel MKL.";
return false;
}

return true;
}

static bool FusedConv2DRewrite(const Node* n) {
// MKL DNN currently doesn't support all fusions that grappler fuses
// together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
@@ -2175,23 +2209,21 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
// Number of input slots to original op
// Input slots are represented by .Input() calls in REGISTER_OP.
int old_node_input_slots = old_node->op_def().input_arg_size();
// Actual number of inputs can be greater than or equal to number
// of Input slots because inputs of type list could be unfolded.
CHECK_GE(old_node_inputs.size(), old_node_input_slots);
int nn_slot_idx = 0; // slot index for inputs of new node

// Let's copy all inputs (TF tensors) of original node to new node.
int iidx = 0;
for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
// An input slot could be a single tensor or a list. We need
// to handle this case accordingly.
CHECK_LT(iidx, old_node_inputs.size());
const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
&new_node_inputs);
int tensor_list_length = GetTensorListLength(arg, old_node);
if (tensor_list_length != 0) {
GetNodesProducingTFTensorList(old_node_inputs, &iidx,
tensor_list_length, &new_node_inputs);
}
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@@ -2224,13 +2256,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
// An input slot could be a single tensor or a list. We need
// to handle this case accordingly.
CHECK_LT(iidx, old_node_inputs.size());
const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
&new_node_inputs);
int tensor_list_length = GetTensorListLength(arg, old_node);
if (tensor_list_length != 0) {
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
tensor_list_length, &new_node_inputs);
}
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@@ -3746,6 +3779,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
n->type_string() != csinfo_.pad_with_conv2d &&
n->type_string() != csinfo_.pad_with_fused_conv2d &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
n->type_string() != csinfo_.fused_batch_norm_ex &&
n->type_string() != csinfo_.fused_conv2d &&
n->type_string() != csinfo_.fused_depthwise_conv2d &&
n->type_string() != csinfo_.fused_matmul &&
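For reference, a minimal standalone sketch (outside TensorFlow, with a hypothetical `NodeAttrs` struct standing in for the real `Node`/`GetNodeAttr` machinery) of the gating condition the new `FusedBatchNormExRewrite` check applies: `_FusedBatchNormEx` is rewritten to `_MklFusedBatchNormEx` only when the node has no side input and its fused activation is ReLU, mirroring the positive and negative cases in the tests below.

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-in for the attributes read off a _FusedBatchNormEx node.
struct NodeAttrs {
  int num_side_inputs;
  std::string activation_mode;
};

// Mirrors the intent of FusedBatchNormExRewrite: only BN followed by ReLU,
// with no extra side input, is handed to the MKL fused kernel.
bool ShouldRewriteToMkl(const NodeAttrs& attrs) {
  if (attrs.num_side_inputs != 0) return false;       // side inputs unsupported
  if (attrs.activation_mode != "Relu") return false;  // only ReLU is fused
  return true;
}

int main() {
  std::cout << ShouldRewriteToMkl({0, "Relu"}) << "\n";      // 1: rewritten
  std::cout << ShouldRewriteToMkl({1, "Relu"}) << "\n";      // 0: side input present
  std::cout << ShouldRewriteToMkl({0, "Identity"}) << "\n";  // 0: activation not ReLU
  return 0;
}
```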
106 changes: 106 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3240,6 +3240,112 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}

#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 0 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \
"DMT/_4(Const);E(Input);" \
"F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive);
#undef REGISTER_TEST

// Rewrite test for _FusedBatchNormEx Op with side input
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '" #INPUT \
"'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);E(Input);" \
"F(" #INPUT \
");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
#undef REGISTER_TEST

// Rewrite test for _FusedBatchNormEx Op with Identity activation
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Identity' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);E(Input);" \
"G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
#undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1

TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph(
"node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"