Skip to content

Commit 8287386

Browse files
authored
Revert "Fuse BatchMatMul with Mul" (#26)
1 parent f1d82fc commit 8287386

File tree

7 files changed

+33
-463
lines changed

7 files changed

+33
-463
lines changed

tensorflow/core/graph/mkl_layout_pass.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
277277
csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
278278
csinfo_.fused_matmul = "_FusedMatMul";
279279
csinfo_.fused_matmul_grad = "_FusedMatMulGrad";
280-
csinfo_.fused_batch_matmul = "_FusedBatchMatMul";
281-
csinfo_.fused_batch_matmul_v2 = "_FusedBatchMatMulV2";
282280
csinfo_.gelu = "Gelu";
283281
csinfo_.gelu_grad = "GeluGrad";
284282
csinfo_.identity = "Identity";
@@ -306,8 +304,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
306304
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
307305
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
308306
csinfo_.mkl_fused_matmul_grad = "_MklFusedMatMulGrad";
309-
csinfo_.mkl_fused_batch_matmul = "_MklFusedBatchMatMul";
310-
csinfo_.mkl_fused_batch_matmul_v2 = "_MklFusedBatchMatMulV2";
311307
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
312308
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
313309
csinfo_.pad = "Pad";
@@ -506,12 +502,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
506502
CopyAttrsAll, AlwaysRewrite,
507503
kRewriteForLayoutPropagation});
508504

509-
rinfo_.push_back({csinfo_.fused_batch_matmul,
510-
csinfo_.mkl_fused_batch_matmul, CopyAttrsAll,
511-
AlwaysRewrite, kRewriteForOpNameChange});
512-
rinfo_.push_back({csinfo_.fused_batch_matmul_v2,
513-
csinfo_.mkl_fused_batch_matmul_v2, CopyAttrsAll,
514-
AlwaysRewrite, kRewriteForOpNameChange});
515505
rinfo_.push_back({csinfo_.gelu, mkl_op_registry::GetMklOpName(csinfo_.gelu),
516506
CopyAttrsAll, GeluRewrite, kRewriteForLayoutPropagation});
517507
rinfo_.push_back({csinfo_.gelu_grad,
@@ -964,8 +954,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
964954
string fused_depthwise_conv2d;
965955
string fused_matmul;
966956
string fused_matmul_grad;
967-
string fused_batch_matmul;
968-
string fused_batch_matmul_v2;
969957
string gelu;
970958
string gelu_grad;
971959
string identity;
@@ -991,8 +979,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
991979
string mkl_fused_depthwise_conv2d;
992980
string mkl_fused_matmul;
993981
string mkl_fused_matmul_grad;
994-
string mkl_fused_batch_matmul;
995-
string mkl_fused_batch_matmul_v2;
996982
string mkl_pad_with_conv2d;
997983
string mkl_pad_with_fused_conv2d;
998984
string mul;
@@ -3828,8 +3814,6 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
38283814
n->type_string() != csinfo_.fused_depthwise_conv2d &&
38293815
n->type_string() != csinfo_.fused_matmul &&
38303816
n->type_string() != csinfo_.fused_matmul_grad &&
3831-
n->type_string() != csinfo_.fused_batch_matmul &&
3832-
n->type_string() != csinfo_.fused_batch_matmul_v2 &&
38333817
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
38343818
T)) {
38353819
return nullptr;

tensorflow/core/graph/mkl_layout_pass_test.cc

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,31 +2033,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
20332033
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMulGrad_Positive);
20342034
#undef REGISTER_TEST
20352035

2036-
// Test set: _FusedBatchMatMul -> MklFusedBatchMatMul rewrite tests
2037-
#define REGISTER_TEST(NAME, T, INPUT) \
2038-
TEST_F(MklLayoutPassTest, NAME##_##T) { \
2039-
InitGraph( \
2040-
"node { name: 'A' op: '" #INPUT "'}" \
2041-
"node { name: 'B' op: '" #INPUT "'}" \
2042-
"node { name: 'C' op: '" #INPUT "'}" \
2043-
"node { name: 'D' op: '_FusedBatchMatMul'" \
2044-
" attr { key: 'T' value { type:" #T "} }" \
2045-
" attr { key: 'adj_x' value { b: false } }" \
2046-
" attr { key: 'adj_y' value { b: false } }" \
2047-
" attr { key: 'num_args' value { i: 1 } }" \
2048-
" attr { key: 'fused_ops' value { list: {s: 'Scale'} } }" \
2049-
" input: ['A', 'B', 'C']}" \
2050-
"node { name: 'Z' op: 'Zeta'" \
2051-
" attr {key: 'T' value { type: " #T " } }" \
2052-
" input: ['D', 'C']}"); \
2053-
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
2054-
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
2055-
"D(_MklFusedBatchMatMul);Z(Zeta)" \
2056-
"|A->D;B->D:1;C->D:2;C->Z:1;D->Z"); \
2057-
}
2058-
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchMatMul_Positive)
2059-
#undef REGISTER_TEST
2060-
20612036
// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
20622037
// padding is VALID type
20632038
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,

tensorflow/core/grappler/optimizers/mkl_remapper_test.cc

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -765,99 +765,6 @@ TEST_F(MklRemapperTest, FuseBatchNormWithAddAndRelu) {
765765
EXPECT_EQ(found, 3);
766766
}
767767
}
768-
769-
class MklFuseBatchMatMulWithMul : public MklRemapperTest {
770-
public:
771-
void VerifyFused(bool adjx, bool adjy) {
772-
using ::tensorflow::ops::Placeholder;
773-
int b = 2;
774-
int m = 2;
775-
int k = 3;
776-
int n = 4;
777-
778-
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
779-
780-
auto input_shape = ops::Placeholder::Shape({b, m, k});
781-
if (adjx) input_shape = ops::Placeholder::Shape({b, k, m});
782-
auto weight_shape = ops::Placeholder::Shape({b, k, n});
783-
if (adjy) weight_shape = ops::Placeholder::Shape({b, n, k});
784-
785-
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
786-
auto weight = Placeholder(s.WithOpName("weight"), DT_FLOAT, weight_shape);
787-
788-
auto batchmatmul =
789-
ops::BatchMatMulV2(s.WithOpName("batchmatmul"), input, weight,
790-
ops::BatchMatMulV2::Attrs().AdjX(adjx).AdjY(adjy));
791-
auto scale = ops::Const(s.WithOpName("scale"), {10.0f});
792-
auto mul = ops::Multiply(s.WithOpName("mul"), batchmatmul, scale);
793-
794-
auto fetch_mul = ops::Identity(s.WithOpName("fetch_mul"), mul);
795-
796-
auto input_t = GenerateRandomTensor<DT_FLOAT>({b, m, k});
797-
if (adjx) input_t = GenerateRandomTensor<DT_FLOAT>({b, k, m});
798-
auto weight_t = GenerateRandomTensor<DT_FLOAT>({b, k, n});
799-
if (adjy) weight_t = GenerateRandomTensor<DT_FLOAT>({b, n, k});
800-
801-
GrapplerItem item;
802-
item.fetch = {"fetch_mul"};
803-
item.feed = {{"input", input_t}, {"weight", weight_t}};
804-
TF_CHECK_OK(s.ToGraphDef(&item.graph));
805-
806-
// Place all nodes on CPU.
807-
for (int i = 0; i < item.graph.node_size(); ++i) {
808-
item.graph.mutable_node(i)->set_device("/device:CPU:0");
809-
}
810-
811-
Remapper optimizer(RewriterConfig::ON);
812-
GraphDef output;
813-
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
814-
815-
int found = 0;
816-
for (const NodeDef& node : output.node()) {
817-
if (node.name() == "mul") {
818-
EXPECT_EQ("_FusedBatchMatMulV2", node.op());
819-
EXPECT_EQ("input", node.input(0));
820-
EXPECT_EQ("weight", node.input(1));
821-
EXPECT_EQ("scale", node.input(2));
822-
823-
const auto fused_ops = node.attr().at("fused_ops").list().s();
824-
EXPECT_EQ(1, fused_ops.size());
825-
EXPECT_EQ("Mul", fused_ops[0]);
826-
found++;
827-
}
828-
}
829-
EXPECT_EQ(1, found);
830-
831-
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
832-
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
833-
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
834-
}
835-
};
836-
837-
TEST_F(MklFuseBatchMatMulWithMul, a0b0) {
838-
bool adjx = false;
839-
bool adjy = false;
840-
this->VerifyFused(adjx, adjy);
841-
}
842-
843-
TEST_F(MklFuseBatchMatMulWithMul, a1b0) {
844-
bool adjx = true;
845-
bool adjy = false;
846-
this->VerifyFused(adjx, adjy);
847-
}
848-
849-
TEST_F(MklFuseBatchMatMulWithMul, a0b1) {
850-
bool adjx = false;
851-
bool adjy = true;
852-
this->VerifyFused(adjx, adjy);
853-
}
854-
855-
TEST_F(MklFuseBatchMatMulWithMul, a1b1) {
856-
bool adjx = true;
857-
bool adjy = true;
858-
this->VerifyFused(adjx, adjy);
859-
}
860-
861768
#endif // ENABLE_MKLDNN_V1
862769

863770
} // namespace grappler

0 commit comments

Comments
 (0)