Skip to content

Commit

Permalink
[Refactor] Clean up duplicated code in matmul_op/batch_matmul_op and …
Browse files Browse the repository at this point in the history
…associated tests.
  • Loading branch information
liutongxuan committed Aug 31, 2022
1 parent b7507f8 commit 2463392
Show file tree
Hide file tree
Showing 16 changed files with 351 additions and 985 deletions.
2 changes: 1 addition & 1 deletion tensorflow/c/eager/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
EXPECT_NE(TF_OK, TF_GetCode(status));
EXPECT_EQ(nullptr, t);
const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
<< TF_Message(status);
// Since error is not cleared, the following copy with correct device will
Expand Down
7 changes: 0 additions & 7 deletions tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
Expand All @@ -915,7 +914,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
Expand All @@ -934,7 +932,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
Expand All @@ -953,7 +950,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
Expand Down Expand Up @@ -991,7 +987,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
Expand All @@ -1008,7 +1003,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
Expand All @@ -1022,7 +1016,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/framework/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
TEST(TFunc, WXPlusB) {
auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm:product:0, b)
return y = y:z:0
}
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/framework/function_testlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,7 @@ FunctionDef WXPlusB() {
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
{"_kernel", "eigen"}}},
{"transpose_b", false}}},
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
}

Expand Down
79 changes: 22 additions & 57 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ load(
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
"mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
Expand Down Expand Up @@ -3962,7 +3961,6 @@ cc_library(
deps = [
":aggregate_ops",
":argmax_op",
":batch_matmul_op",
":betainc_op",
":bincount_op",
":bucketize_op",
Expand Down Expand Up @@ -3999,15 +3997,27 @@ tf_kernel_library(

tf_kernel_library(
name = "batch_matmul_op",
deps = [":matmul_op"],
)

tf_kernel_library(
name = "matmul_op",
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
hdrs = ["batch_matmul_op_impl.h"],
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda([
"//tensorflow/core/kernels:gpu_utils",
"//tensorflow/core/platform:tensor_float_32_utils",
]),
hdrs = ["matmul_op_impl.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
prefix = "matmul_op",
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)

tf_mkl_kernel_library(
Expand Down Expand Up @@ -4110,30 +4120,6 @@ tf_kernel_library(
]
)

tf_kernel_library(
name = "matmul_op",
srcs = [
"matmul_op.cc",
"matmul_op_fused.cc",
],
hdrs = ["matmul_op.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
":ops_util",
":gpu_utils",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
)

tf_mkl_kernel_library(
name = "mkl_matmul_op",
srcs = [
Expand Down Expand Up @@ -4397,25 +4383,6 @@ tf_cuda_cc_test(
],
)

tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.cc"],
deps = [
":batch_matmul_op",
":broadcast_to_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",
Expand Down Expand Up @@ -6783,8 +6750,8 @@ filegroup(
"identity_op.h",
"immutable_constant_op.cc",
"immutable_constant_op.h",
"matmul_op.cc",
"matmul_op.h",
"matmul_op_impl.h",
"matmul_op_real.cc",
"no_op.cc",
"no_op.h",
"non_max_suppression_op.cc",
Expand Down Expand Up @@ -6860,7 +6827,6 @@ filegroup(
srcs = [
"argmax_op.h",
"avgpooling_op.h",
"batch_matmul_op_impl.h",
"batch_norm_op.h",
"control_flow_ops.h",
"conv_2d.h",
Expand Down Expand Up @@ -6931,7 +6897,6 @@ filegroup(
srcs = [
"argmax_op.cc",
"avgpooling_op.cc",
"batch_matmul_op_real.cc",
"batch_norm_op.cc",
"bcast_ops.cc",
"check_numerics_op.cc",
Expand Down
Loading

0 comments on commit 2463392

Please sign in to comment.