
Commit a5b7556

Merge pull request #13 from Intel-tensorflow/guizili/matmul_fusion
Add matmul_grad_filter and biasaddgrad fusion
2 parents 69ccb40 + 9bcff32
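The kernel itself is not in the hunks shown on this page, but the shape function added below (outputs: a [rows, cols] matrix and a [cols] vector) and the 'fused_ops: BiasAddGrad' attribute in the tests suggest the following semantics for the default transpose_a = transpose_b = false case: given forward input a of shape [batch, in] and incoming gradient b of shape [batch, out], the op produces the filter gradient a^T * b of shape [in, out] and the fused BiasAddGrad, i.e. b reduced over the batch dimension, of shape [out]. A minimal sketch with plain arrays; the function name and signature are hypothetical, not the commit's kernel:

#include <cstddef>
#include <vector>

// Hypothetical sketch of the math behind _FusedMatMulGrad with both
// transpose attrs false. The commit itself only adds shape inference,
// the layout-pass rewrite, and tests; this reconstructs the intended
// semantics for illustration.
void FusedMatMulGradSketch(const std::vector<std::vector<float>>& a,  // [batch, in]
                           const std::vector<std::vector<float>>& b,  // [batch, out]
                           std::vector<std::vector<float>>* filter_grad,  // [in, out]
                           std::vector<float>* bias_grad) {              // [out]
  const std::size_t batch = a.size();
  const std::size_t in = a[0].size();
  const std::size_t out = b[0].size();
  filter_grad->assign(in, std::vector<float>(out, 0.0f));
  bias_grad->assign(out, 0.0f);
  for (std::size_t k = 0; k < batch; ++k) {
    for (std::size_t j = 0; j < out; ++j) {
      (*bias_grad)[j] += b[k][j];  // fused BiasAddGrad: reduce over batch
      for (std::size_t i = 0; i < in; ++i) {
        (*filter_grad)[i][j] += a[k][i] * b[k][j];  // a^T * b
      }
    }
  }
}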

File tree

10 files changed: +1029 −2 lines changed

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 30 additions & 0 deletions

@@ -130,6 +130,36 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+Status MatMulGradFilterShape(shape_inference::InferenceContext* c) {
+  ShapeHandle a;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
+
+  ShapeHandle b;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
+
+  bool transpose_a, transpose_b;
+  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
+  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
+  DimensionHandle output_rows = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
+  DimensionHandle output_cols = c->Dim(b, 1);
+
+  if (transpose_b) {
+    auto tmp = output_rows;
+    output_rows = output_cols;
+    output_cols = tmp;
+  }
+
+  // Validate that the inner shapes are compatible.
+  DimensionHandle inner_a = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
+  DimensionHandle inner_b = c->Dim(b, 0);
+  DimensionHandle merged;
+  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
+
+  c->set_output(0, c->Matrix(output_rows, output_cols));
+  c->set_output(1, c->Vector(output_cols));
+  return Status::OK();
+}
+
 namespace {
 
 // Validate that an Einsum subscript contains exactly one or zero ellipsis; and
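MatMulGradFilterShape emits two outputs, a matrix and a vector, so the op it serves must be registered with two outputs. The actual REGISTER_OP for _FusedMatMulGrad lives in one of the six changed files not shown on this page; the sketch below is a plausible reconstruction, with input/output names, attribute defaults, and the type constraint assumed rather than taken from the commit:

// Hypothetical sketch of how _FusedMatMulGrad could be registered with
// this shape function; names and defaults are assumptions, not the
// commit's exact text.
REGISTER_OP("_FusedMatMulGrad")
    .Input("a: T")              // forward input (e.g. activations)
    .Input("b: T")              // incoming gradient
    .Output("product: T")       // gradient w.r.t. the filter
    .Output("bias_grad: T")     // fused BiasAddGrad output
    .Attr("T: {bfloat16, float}")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("fused_ops: list(string) = []")
    .SetShapeFn(shape_inference::MatMulGradFilterShape);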

tensorflow/core/framework/common_shape_fns.h

Lines changed: 1 addition & 0 deletions

@@ -102,6 +102,7 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
 
 // Shape function for MatMul-like operations.
 Status MatMulShape(shape_inference::InferenceContext* c);
+Status MatMulGradFilterShape(shape_inference::InferenceContext* c);
 
 // Shape function for Batched MatMul-like operations with broadcasting across
 // batch dimensions.

tensorflow/core/graph/mkl_layout_pass.cc

Lines changed: 8 additions & 0 deletions

@@ -275,6 +275,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.fused_conv2d = "_FusedConv2D";
     csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
     csinfo_.fused_matmul = "_FusedMatMul";
+    csinfo_.fused_matmul_grad = "_FusedMatMulGrad";
     csinfo_.identity = "Identity";
     csinfo_.leakyrelu = "LeakyRelu";
     csinfo_.leakyrelu_grad = "LeakyReluGrad";
@@ -298,6 +299,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
     csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
     csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
+    csinfo_.mkl_fused_matmul_grad = "_MklFusedMatMulGrad";
     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
     csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
     csinfo_.pad = "Pad";
@@ -487,6 +489,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
                       CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
+    rinfo_.push_back({csinfo_.fused_matmul_grad, csinfo_.mkl_fused_matmul_grad,
+                      CopyAttrsAll, AlwaysRewrite,
+                      kRewriteForLayoutPropagation});
 
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
@@ -933,6 +938,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string fused_conv2d;
     string fused_depthwise_conv2d;
     string fused_matmul;
+    string fused_matmul_grad;
     string identity;
     string leakyrelu;
     string leakyrelu_grad;
@@ -954,6 +960,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string mkl_fused_conv2d;
     string mkl_fused_depthwise_conv2d;
     string mkl_fused_matmul;
+    string mkl_fused_matmul_grad;
     string mkl_pad_with_conv2d;
     string mkl_pad_with_fused_conv2d;
     string mul;
@@ -3742,6 +3749,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
       n->type_string() != csinfo_.fused_conv2d &&
       n->type_string() != csinfo_.fused_depthwise_conv2d &&
      n->type_string() != csinfo_.fused_matmul &&
+      n->type_string() != csinfo_.fused_matmul_grad &&
       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
                                 T)) {
     return nullptr;
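For orientation: each rinfo_ entry registered above pairs an op name with its Mkl replacement, an attribute-copying helper (CopyAttrsAll for the new op, i.e. every attribute is carried over verbatim), and a predicate gating the rewrite (AlwaysRewrite, so every _FusedMatMulGrad node is rewritten). A simplified sketch of such a record; the real struct in mkl_layout_pass.cc has more fields (e.g. the rewrite cause, kRewriteForLayoutPropagation here) and different signatures:

#include <functional>
#include <string>

class Node;         // TensorFlow graph node, forward-declared for the sketch
class NodeBuilder;  // builder for the replacement node

// Illustrative only: a simplified picture of the rewrite records kept in
// rinfo_. Field names and signatures here are assumptions, not the real
// struct.
struct RewriteRecordSketch {
  std::string name;      // op to match, e.g. "_FusedMatMulGrad"
  std::string new_name;  // Mkl replacement, e.g. "_MklFusedMatMulGrad"
  std::function<void(const Node*, NodeBuilder*)> copy_attrs;  // CopyAttrsAll here
  std::function<bool(const Node*)> rewrite_rule;              // AlwaysRewrite here
};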

tensorflow/core/graph/mkl_layout_pass_test.cc

Lines changed: 24 additions & 0 deletions

@@ -2009,6 +2009,30 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Positive)
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
 #undef REGISTER_TEST
 
+// Test set: _FusedMatMulGrad -> MklFusedMatMulGrad rewrite tests
+#define REGISTER_TEST(NAME, T, INPUT)                                       \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                   \
+    InitGraph(                                                              \
+        "node { name: 'A' op: '" #INPUT "'}"                                \
+        "node { name: 'B' op: '" #INPUT "'}"                                \
+        "node { name: 'D' op: '_FusedMatMulGrad'"                           \
+        " attr { key: 'T' value { type:" #T "} }"                           \
+        " attr { key: 'transpose_a' value { b: false } }"                   \
+        " attr { key: 'transpose_b' value { b: false } }"                   \
+        " attr { key: 'fused_ops' value { list: {s: 'BiasAddGrad'} } }"     \
+        " input: ['A', 'B']}"                                               \
+        "node { name: 'Z' op: 'Zeta'"                                       \
+        " attr {key: 'T' value { type: " #T " } }"                          \
+        " input: ['D']}");                                                  \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                \
+              "A(" #INPUT ");B(" #INPUT ");D(_MklFusedMatMulGrad);"         \
+              "DMT/_0(Const);DMT/_1(Const);Z(Zeta)"                         \
+              "|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;"  \
+              "B->D:1;D->Z;DMT/_0->D:2;DMT/_1->D:3");                       \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMulGrad_Positive);
+#undef REGISTER_TEST
+
 // Merge test for PadWithFusedConv2D Op with BiasAdd fusion
 // padding is VALID type
 // A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
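Reading the expected string in the new test: DoMklLayoutOptimizationPass() canonicalizes the rewritten graph as "nodes|edges". D is rewritten from _FusedMatMulGrad to _MklFusedMatMulGrad, and two Const nodes DMT/_0 and DMT/_1 appear as D's inputs 2 and 3; in these layout-pass tests such DMT nodes supply dummy Mkl metadata tensors for inputs that arrive from non-Mkl ops, and the A:control->DMT/_*:control edges pin them to the same execution frame as D's data inputs. This reading is inferred from the surrounding tests on this page rather than stated in the commit.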
