From 8dac20a91ccb6682cd32a6c4fb1d0a59e64e2fe9 Mon Sep 17 00:00:00 2001
From: jcf94
Date: Sat, 24 Jul 2021 14:51:18 +0800
Subject: [PATCH] Add grad support for batch_matmul

---
 python/tvm/relay/op/_tensor_grad.py         |  56 ++++++++++-
 python/tvm/relay/op/nn/_nn.py               |   4 +-
 python/tvm/relay/op/op_attrs.py             |   5 +
 python/tvm/relay/op/strategy/cuda.py        |  15 ++-
 python/tvm/topi/cuda/batch_matmul.py        | 103 +++++++++++++++++---
 python/tvm/topi/cuda/tensorcore_alter_op.py |  10 +-
 python/tvm/topi/x86/batch_matmul.py         |  64 ++++++++----
 tests/python/relay/test_op_grad_level10.py  |  20 +++-
 tests/python/relay/test_op_level10.py       |  12 +--
 9 files changed, 239 insertions(+), 50 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index fa2772c1299f..3793f947c5cc 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad):
     GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
     """
     lhs, rhs = orig.args
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+        # ki, jk -> ij
+        # jk, ij -> ki
+        # ij, ki -> jk
+        return [
+            collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs),
+            collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+        # ki, kj -> ij
+        # kj, ij -> ki
+        # ki, ij -> kj
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs
+            ),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+        # ik, jk -> ij
+        # ij, jk -> ik
+        # ij, ik -> jk
+        # Keep using the NT-format batch_matmul here to avoid introducing extra ops
+        # TODO(jcf94): Merge all cases into the normal batch_matmul when it is finally ready
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    grad,
+                    transpose(rhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                lhs,
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    transpose(grad, [0, 2, 1]),
+                    transpose(lhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                rhs,
+            ),
+        ]
+    # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+    # ik, kj -> ij
+    # ij, kj -> ik
+    # ik, ij -> kj
     return [
-        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
-        collapse_sum_like(
-            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
-        ),
+        collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs),
+        collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs),
     ]

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 8382eb0f00b0..96cef8bc3588 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1279,8 +1279,8 @@ def dense_pack_shape_func(attrs, inputs, _):
 def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
     out = output_tensor((tensor_a_shape.shape[0],), "int64")
     out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
-    out[1] = tensor_a_shape[2 if transpose_a else 1]
-    out[2] = tensor_b_shape[1 if transpose_b else 2]
+    out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1]
+    out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2]
     return out
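The einsum comments in each branch above encode the forward contraction followed by the
two gradient contractions. As a quick sanity check (not part of the patch; shapes are
arbitrary), the (True, True) rules can be replayed with plain numpy:

    import numpy as np

    # forward: ki, jk -> ij;  grads: jk, ij -> ki  and  ij, ki -> jk
    b, i, j, k = 2, 3, 4, 5
    lhs = np.random.randn(b, k, i)   # transpose_a=True layout, [batch, K, M]
    rhs = np.random.randn(b, j, k)   # transpose_b=True layout, [batch, N, K]
    grad = np.random.randn(b, i, j)  # same shape as the forward output

    out = np.einsum("bki,bjk->bij", lhs, rhs)
    # matches _nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True)
    grad_lhs = np.einsum("bjk,bij->bki", rhs, grad)
    # matches _nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True)
    grad_rhs = np.einsum("bij,bki->bjk", grad, lhs)
    assert grad_lhs.shape == lhs.shape and grad_rhs.shape == rhs.shape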
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 2d185bcee798..507dd9371a97 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -74,6 +74,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""


+@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs")
+class BatchMatmulAttrs(Attrs):
+    """Attributes for nn.batch_matmul"""
+
+
 @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
 class SoftmaxAttrs(Attrs):
     """Attributes for nn.softmax"""

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 1f999a810164..1dadbebf87a2 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     x, y = inputs
-    if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
+    if (
+        x.dtype == "int8"
+        and y.dtype == "int8"
+        and out_type.dtype == "int32"
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
+    ):
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
@@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
-    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+    if (
+        target.kind.name == "cuda"
+        and nvcc.have_tensorcore(target=target)
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
+    ):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index efd67e0eebe0..3fc8a584b557 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -27,9 +27,49 @@


 @autotvm.register_topi_compute("batch_matmul.cuda")
-def batch_matmul(cfg, x, y, out_shape=None):
-    """Compute conv2d with NCHW layout"""
-    return nn.batch_matmul(x, y)
+def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True):
+    """Compute batch matrix multiplication of `x` and `y`.
+
+    Both `x` and `y` can be transposed. For legacy reasons, we use the NT format
+    (transpose_a=False, transpose_b=True) by default.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return nn.batch_matmul(
+        x,
+        y,
+        oshape=out_shape,
+        out_dtype=out_dtype,
+        transpose_a=transpose_a,
+        transpose_b=transpose_b,
+    )


 @autotvm.register_topi_schedule("batch_matmul.cuda")
@@ -140,20 +180,37 @@ def _callback(op):


 @autotvm.register_topi_compute("batch_matmul_cublas.cuda")
-def batch_matmul_cublas(cfg, x, y, out_shape=None, transpose_a=False, transpose_b=True):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
+def batch_matmul_cublas(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
+    """Compute batch matrix multiplication of `x` and `y`.
+
+    Both `x` and `y` can be transposed. For legacy reasons, we use the NT format
+    (transpose_a=False, transpose_b=True) by default.

     Parameters
     ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
     x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
+        3-D with shape [batch, M, K] or [batch, K, M].

     y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
+        3-D with shape [batch, K, N] or [batch, N, K].

-    out_shape : None
-        The output shape
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.

     Returns
     -------
@@ -181,7 +238,31 @@ def schedule_batch_matmul_cublas(_, outs):

 @autotvm.register_topi_compute("batch_matmul_int8.cuda")
 def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
-    """Batch Matmul operator for int8 on CUDA"""
+    """Batch matmul operator for int8 on CUDA.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
     if out_dtype is None:
         out_dtype = x.dtype
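With this change the CUDA compute simply forwards its transpose flags to
`nn.batch_matmul` in TOPI. A minimal sketch of the underlying TOPI call (not part of the
patch; assumes a TVM build with this patch series applied, where `topi.nn.batch_matmul`
accepts the transpose keywords):

    import tvm
    from tvm import te, topi

    # Both operands in transposed layout: A is [batch, K, M], B is [batch, N, K].
    A = te.placeholder((2, 5, 3), name="A", dtype="float32")
    B = te.placeholder((2, 4, 5), name="B", dtype="float32")
    C = topi.nn.batch_matmul(A, B, transpose_a=True, transpose_b=True)
    print(C.shape)  # [2, 3, 4], i.e. [batch, M, N]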
diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
index fffb0d6d48fc..1b0536495290 100644
--- a/python/tvm/topi/cuda/tensorcore_alter_op.py
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -19,7 +19,7 @@
 import logging
 import math

-from tvm import relay
+from tvm import relay, tir

 from .. import nn

@@ -56,6 +56,14 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
     B, M, K = x_tensor.shape
     B, N, K = y_tensor.shape
+    if (
+        isinstance(B, tir.expr.Any)
+        or isinstance(M, tir.expr.Any)
+        or isinstance(K, tir.expr.Any)
+        or isinstance(N, tir.expr.Any)
+    ):
+        # Alter op layout is not supported for dynamic shapes
+        return
     M = M.value
     K = K.value
     N = N.value
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index e77768f98507..c032a5c3f536 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -62,10 +62,6 @@ def batch_matmul(
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    if cfg.is_fallback and not transpose_a and transpose_b:
-        B, N, K = get_const_tuple(tensor_a.shape)
-        _default_batch_matmul_config(cfg, B, N, K)
-
     return nn.batch_matmul(
         tensor_a,
         tensor_b,
@@ -145,20 +141,32 @@ def _default_batch_matmul_config(cfg, M, N, K):
     cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


-def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.
+def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib):
+    """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and
+    `tensor_b` are data in batch, using one of the BLAS libraries. Supports broadcasting
+    in the batch dimension.

     Parameters
     ----------
     cfg : ConfigSpace
         Autotvm tuning space config file
-    x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
-    y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
-    out_shape : tuple or None
-        Shape of the output
+
+    tensor_a : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    tensor_b : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    trans_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    trans_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
     lib : A contrib module which implements batch_matmul function
         cblas and mkl are supported

@@ -167,9 +175,15 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
-    XB, M, XK = get_const_tuple(x.shape)
-    YB, N, YK = get_const_tuple(y.shape)
+    assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+    if trans_a:
+        XB, XK, M = get_const_tuple(tensor_a.shape)
+    else:
+        XB, M, XK = get_const_tuple(tensor_a.shape)
+    if trans_b:
+        YB, N, YK = get_const_tuple(tensor_b.shape)
+    else:
+        YB, YK, N = get_const_tuple(tensor_b.shape)
     assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistent"
     if out_shape is not None:
@@ -177,13 +191,17 @@
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
-    return lib.batch_matmul(x, y, False, True)
+    return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b)


 @autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y, out_shape=None):
+def batch_matmul_cblas(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using cblas"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas
+    )


 @autotvm.register_topi_schedule("batch_matmul_cblas.x86")
@@ -193,9 +211,13 @@ def schedule_batch_matmul_cblas(_, outs):


 @autotvm.register_topi_compute("batch_matmul_mkl.x86")
-def batch_matmul_mkl(cfg, x, y, out_shape=None):
+def batch_matmul_mkl(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using mkl"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl
+    )


 @autotvm.register_topi_schedule("batch_matmul_mkl.x86")
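The operand-shape bookkeeping in `batch_matmul_blas_common` is the part most easily
gotten wrong: M comes from `tensor_a`, while N and K must be unpacked from `tensor_b`.
A plain-Python mirror of that logic (not part of the patch; `blas_shapes` and the shapes
are made up for illustration):

    # M comes from tensor_a, N from tensor_b; the reduction dim K must agree.
    def blas_shapes(a_shape, b_shape, trans_a, trans_b):
        xb, xk, m = a_shape if trans_a else (a_shape[0], a_shape[2], a_shape[1])
        yb, n, yk = b_shape if trans_b else (b_shape[0], b_shape[2], b_shape[1])
        assert xk == yk, "shapes of tensor_a and tensor_b are inconsistent"
        return max(xb, yb), m, n

    assert blas_shapes((5, 16, 32), (5, 20, 32), False, True) == (5, 16, 20)  # NT
    assert blas_shapes((5, 32, 16), (5, 32, 20), True, False) == (5, 16, 20)  # TN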
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index e2145f77b366..8d961eb60b18 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -62,10 +62,24 @@ def test_checkpoint():
     check_grad(relay.Function(inputs, out_single))


+def verify_batch_matmul_grad(a_shape, b_shape, transpose_a, transpose_b):
+    tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32"))
+    tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32"))
+    check_grad(
+        relay.Function(
+            [tensor_a, tensor_b],
+            relay.op.nn.batch_matmul(
+                tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b
+            ),
+        )
+    )
+
+
 def test_batch_matmul_grad():
-    x = relay.var("x", shape=(2, 3, 5), dtype="float64")
-    y = relay.var("y", shape=(2, 4, 5), dtype="float64")
-    check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))
+    verify_batch_matmul_grad((2, 3, 5), (2, 5, 4), False, False)
+    verify_batch_matmul_grad((2, 3, 5), (2, 4, 5), False, True)
+    verify_batch_matmul_grad((2, 5, 3), (2, 5, 4), True, False)
+    verify_batch_matmul_grad((2, 5, 3), (2, 4, 5), True, True)


 def test_reverse_reshape_grad():

diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 3457e0943685..eda7eac1b025 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -354,12 +354,12 @@ def test_batch_matmul():
         zz = run_infer_type(z)
         assert zz.checked_type == relay.TensorType((b, m, n), "float32")

-    verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
-    verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
-    verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
-    verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
-    verify_batch_matmul((1, 32, 16), (1, 16, 32), (1, 16, 16), trans_x=True)
-    verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_y=False)
+    verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), trans_x=False, trans_y=True)
+    verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), trans_x=False, trans_y=True)
+    verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), trans_x=False, trans_y=True)
+    verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20), trans_x=False, trans_y=True)
+    verify_batch_matmul((1, 32, 16), (1, 16, 32), (1, 16, 16), trans_x=True, trans_y=True)
+    verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_x=False, trans_y=False)
     verify_batch_matmul((5, 32, 16), (5, 32, 20), (5, 16, 20), trans_x=True, trans_y=False)
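Taken together, the pieces can be exercised end to end through relay's gradient pass.
A rough sketch (not part of the patch; assumes a build containing this patch, and the
exact executor invocation and result unpacking may vary between TVM versions):

    import numpy as np
    import tvm
    from tvm import relay

    # Differentiate a TT-layout batch_matmul: x is [batch, K, M], y is [batch, N, K].
    x = relay.var("x", shape=(2, 5, 3), dtype="float32")
    y = relay.var("y", shape=(2, 4, 5), dtype="float32")
    fwd = relay.Function(
        [x, y], relay.op.nn.batch_matmul(x, y, transpose_a=True, transpose_b=True)
    )

    mod = relay.transform.InferType()(tvm.IRModule.from_expr(fwd))
    bwd = relay.transform.gradient(mod["main"], mode="first_order")
    bwd_mod = relay.transform.InferType()(tvm.IRModule.from_expr(bwd))

    # The transformed function returns (forward output, (grad wrt x, grad wrt y)).
    out, (grad_x, grad_y) = relay.create_executor(mod=bwd_mod).evaluate()(
        np.random.rand(2, 5, 3).astype("float32"),
        np.random.rand(2, 4, 5).astype("float32"),
    )
    print(out.shape, grad_x.shape, grad_y.shape)  # (2, 3, 4) (2, 5, 3) (2, 4, 5)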