Add grad support for batch_matmul
jcf94 committed Jul 24, 2021
1 parent 8b2ee8b commit 8dac20a
Showing 9 changed files with 239 additions and 50 deletions.
56 changes: 52 additions & 4 deletions python/tvm/relay/op/_tensor_grad.py
@@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad):
     GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
     """
     lhs, rhs = orig.args
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+        # ki, jk -> ij
+        # jk, ij -> ki
+        # ij, ki -> jk
+        return [
+            collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs),
+            collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+        # ki, kj -> ij
+        # kj, ij -> ki
+        # ki, ij -> kj
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs
+            ),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+        # ik, jk -> ij
+        # ij, jk -> ik
+        # ij, ik -> jk
+        # Keep using the NT-format batch_matmul here to avoid introducing extra ops
+        # TODO(jcf94): Merge all to normal batch_matmul when it is finally ready
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    grad,
+                    transpose(rhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                lhs,
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    transpose(grad, [0, 2, 1]),
+                    transpose(lhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                rhs,
+            ),
+        ]
+    # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+    # ik, kj -> ij
+    # ij, kj -> ik
+    # ik, ij -> kj
     return [
-        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
-        collapse_sum_like(
-            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
-        ),
+        collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs),
+        collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs),
     ]
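Editor's note: each three-line einsum comment above reads as the forward contraction, then the rule for the lhs gradient, then the rule for the rhs gradient. A minimal NumPy sanity check of the (transpose_a=False, transpose_b=False) case follows; this is an editorial sketch, not part of the commit, and all names are illustrative.

# Editor's sketch: check the NN-case rules "ij, kj -> ik" and "ik, ij -> kj".
import numpy as np

rng = np.random.default_rng(0)
b, i, k, j = 2, 3, 4, 5
lhs = rng.standard_normal((b, i, k))
rhs = rng.standard_normal((b, k, j))
grad = rng.standard_normal((b, i, j))  # upstream gradient of out = lhs @ rhs

d_lhs = np.einsum("bij,bkj->bik", grad, rhs)  # ij, kj -> ik
d_rhs = np.einsum("bik,bij->bkj", lhs, grad)  # ik, ij -> kj

# These agree with the analytic gradients of sum(grad * (lhs @ rhs)):
assert np.allclose(d_lhs, grad @ rhs.transpose(0, 2, 1))
assert np.allclose(d_rhs, lhs.transpose(0, 2, 1) @ grad)

The other three transpose combinations follow the same pattern, with the stored layouts of lhs and rhs flipped accordingly.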


4 changes: 2 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
@@ -1279,8 +1279,8 @@ def dense_pack_shape_func(attrs, inputs, _):
 def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
     out = output_tensor((tensor_a_shape.shape[0],), "int64")
     out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
-    out[1] = tensor_a_shape[2 if transpose_a else 1]
-    out[2] = tensor_b_shape[1 if transpose_b else 2]
+    out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1]
+    out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2]

     return out
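Editor's note: the rewrite moves the conditional out of the subscript, presumably because the hybrid-script parser used for shape functions does not accept a conditional expression inside an index. The shape rule itself is unchanged: the batch dimension broadcasts, M comes from tensor_a, N from tensor_b, with the index flipped when an operand is stored transposed. As plain Python (editorial sketch, not commit code):

# Editor's sketch of the batch_matmul output-shape rule.
def batch_matmul_out_shape(a_shape, b_shape, transpose_a=False, transpose_b=True):
    batch = max(a_shape[0], b_shape[0])            # batch dims broadcast
    m = a_shape[2] if transpose_a else a_shape[1]  # M from tensor_a
    n = b_shape[1] if transpose_b else b_shape[2]  # N from tensor_b
    return (batch, m, n)

assert batch_matmul_out_shape((4, 16, 32), (4, 8, 32)) == (4, 16, 8)                      # NT (default)
assert batch_matmul_out_shape((4, 16, 32), (4, 32, 8), transpose_b=False) == (4, 16, 8)   # NN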

5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -74,6 +74,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""


+@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs")
+class BatchMatmulAttrs(Attrs):
+    """Attributes for nn.batch_matmul"""
+
+
 @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
 class SoftmaxAttrs(Attrs):
     """Attributes for nn.softmax"""
15 changes: 13 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
@@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     x, y = inputs
-    if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
+    if (
+        x.dtype == "int8"
+        and y.dtype == "int8"
+        and out_type.dtype == "int32"
+        and attrs["transpose_a"] == False
+        and attrs["transpose_b"] == True
+    ):
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
@@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
-    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+    if (
+        target.kind.name == "cuda"
+        and nvcc.have_tensorcore(target=target)
+        and attrs["transpose_a"] == False
+        and attrs["transpose_b"] == True
+    ):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
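Editor's note: the two new guards restrict the specialized int8 and TensorCore kernels to the NT layout they implement; any other transpose combination now falls through to the generic CUDA batch_matmul. The gate, written as a standalone predicate (editorial sketch, not commit code):

# Editor's sketch: the layout/dtype gate the int8 guard above implements.
def use_int8_nt_kernel(x_dtype, y_dtype, out_dtype, transpose_a, transpose_b):
    # The int8 kernel assumes NT layout: A as [batch, M, K], B as [batch, N, K].
    return (
        (x_dtype, y_dtype, out_dtype) == ("int8", "int8", "int32")
        and not transpose_a
        and transpose_b
    )

assert use_int8_nt_kernel("int8", "int8", "int32", False, True)
assert not use_int8_nt_kernel("int8", "int8", "int32", False, False)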
103 changes: 92 additions & 11 deletions python/tvm/topi/cuda/batch_matmul.py
@@ -27,9 +27,49 @@


 @autotvm.register_topi_compute("batch_matmul.cuda")
-def batch_matmul(cfg, x, y, out_shape=None):
-    """Compute conv2d with NCHW layout"""
-    return nn.batch_matmul(x, y)
+def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True):
+    """Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
+
+    Both `tensor_a` and `tensor_b` can be transposed. For legacy reasons, we use NT format
+    (transpose_a=False, transpose_b=True) by default.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+    tensor_a : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+    tensor_b : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return nn.batch_matmul(
+        x,
+        y,
+        oshape=out_shape,
+        out_dtype=out_dtype,
+        transpose_a=transpose_a,
+        transpose_b=transpose_b,
+    )
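Editor's note: a short usage sketch for the generic TOPI kernel this CUDA compute wraps, with a transposed first operand. Illustrative and untested here; it assumes the nn.batch_matmul signature shown in the diff.

# Editor's sketch: TE-level use of topi.nn.batch_matmul with transpose flags.
from tvm import te, topi

A = te.placeholder((8, 64, 32), name="A")  # [batch, K, M], so transpose_a=True
B = te.placeholder((8, 64, 16), name="B")  # [batch, K, N], so transpose_b=False
C = topi.nn.batch_matmul(A, B, transpose_a=True, transpose_b=False)
print(C.shape)  # [8, 32, 16] == [batch, M, N]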


@autotvm.register_topi_schedule("batch_matmul.cuda")
@@ -140,20 +180,37 @@ def _callback(op):


@autotvm.register_topi_compute("batch_matmul_cublas.cuda")
def batch_matmul_cublas(cfg, x, y, out_shape=None, transpose_a=False, transpose_b=True):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
def batch_matmul_cublas(
cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Compute batch matrix multiplication of `x` and `y`.
Both `x` and `y` can be transposed. For legacy reason, we use NT format
(transpose_a=False, transpose_b=True) by default.
Parameters
----------
cfg : ConfigSpace
Autotvm tuning space config file.
x : tvm.te.Tensor
3-D with shape [batch, M, K]
3-D with shape [batch, M, K] or [batch, K, M].
y : tvm.te.Tensor
3-D with shape [batch, N, K]
3-D with shape [batch, K, N] or [batch, N, K].
out_shape : None
The output shape
out_shape : List[Optional]
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul.
transpose_a : Optional[bool] = False
Whether the first tensor is in transposed format.
transpose_b : Optional[bool] = True
Whether the second tensor is in transposed format.
Returns
-------
@@ -181,7 +238,31 @@ def schedule_batch_matmul_cublas(_, outs):

@autotvm.register_topi_compute("batch_matmul_int8.cuda")
def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
"""Batch Matmul operator for int8 on CUDA"""
"""Batch Matmul operator for int8 on CUDA.
Parameters
----------
cfg : ConfigSpace
Autotvm tuning space config file.
x : tvm.te.Tensor
3-D with shape [batch, M, K] or [batch, K, M].
y : tvm.te.Tensor
3-D with shape [batch, K, N] or [batch, N, K].
out_shape : List[Optional]
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul.
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
if out_dtype is None:
out_dtype = x.dtype

10 changes: 9 additions & 1 deletion python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -19,7 +19,7 @@

 import logging
 import math
-from tvm import relay
+from tvm import relay, tir

from .. import nn

@@ -56,6 +56,14 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):

     B, M, K = x_tensor.shape
     B, N, K = y_tensor.shape
+    if (
+        isinstance(B, tir.expr.Any)
+        or isinstance(M, tir.expr.Any)
+        or isinstance(K, tir.expr.Any)
+        or isinstance(N, tir.expr.Any)
+    ):
+        # Dynamic shapes are not supported by alter_op_layout
+        return
     M = M.value
     K = K.value
     N = N.value
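Editor's note: an illustrative sketch of the kind of input that takes the new early-return path, assuming the usual relay.Any() API for dynamic dimensions (not commit code):

# Editor's sketch: a dynamic batch dimension makes legalization bail out.
from tvm import relay

a = relay.var("a", shape=(relay.Any(), 16, 32), dtype="float16")
b = relay.var("b", shape=(relay.Any(), 8, 32), dtype="float16")
c = relay.nn.batch_matmul(a, b)  # NT format by default
# During legalization for a TensorCore target, B is a tir.expr.Any, so
# _batch_matmul_legalize returns None and the op is left unchanged.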
64 changes: 43 additions & 21 deletions python/tvm/topi/x86/batch_matmul.py
@@ -62,10 +62,6 @@ def batch_matmul(
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    if cfg.is_fallback and not transpose_a and transpose_b:
-        B, N, K = get_const_tuple(tensor_a.shape)
-        _default_batch_matmul_config(cfg, B, N, K)
-
     return nn.batch_matmul(
         tensor_a,
         tensor_b,
@@ -145,20 +141,32 @@ def _default_batch_matmul_config(cfg, M, N, K):
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.
def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib):
"""Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and
`tensor_b` are data in batch, using one of BLAS libraries. Supports broadcasting in batch
dimension.
Parameters
----------
cfg : ConfigSpace
Autotvm tuning space config file
x : tvm.te.Tensor
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
tensor_a : tvm.te.Tensor
3-D with shape [batch, M, K] or [batch, K, M].
tensor_b : tvm.te.Tensor
3-D with shape [batch, K, N] or [batch, N, K].
out_shape : List[Optional]
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.
trans_a : Optional[bool] = False
Whether the first tensor is in transposed format.
trans_b : Optional[bool] = True
Whether the second tensor is in transposed format.
lib : A contrib module which implements batch_matmul function
cblas and mkl are supported
@@ -167,23 +175,33 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
-    XB, M, XK = get_const_tuple(x.shape)
-    YB, N, YK = get_const_tuple(y.shape)
+    assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+    if trans_a:
+        XB, XK, M = get_const_tuple(tensor_a.shape)
+    else:
+        XB, M, XK = get_const_tuple(tensor_a.shape)
+    if trans_b:
+        YB, N, YK = get_const_tuple(tensor_b.shape)
+    else:
+        YB, YK, N = get_const_tuple(tensor_b.shape)
     assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y are inconsistent"
     if out_shape is not None:
         assert out_shape[0] in (XB, YB), "got invalid output shape"
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
-    return lib.batch_matmul(x, y, False, True)
+    return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b)
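Editor's note: a NumPy reference for what the BLAS call computes under each flag combination (editorial sketch, not commit code). The cfg.add_flop(XB * M * N * XK * 2) accounting matches: XB independent matmuls, each costing M * N * XK multiply-add pairs.

# Editor's sketch: NumPy reference for lib.batch_matmul(a, b, trans_a, trans_b).
import numpy as np

def batch_matmul_ref(a, b, trans_a=False, trans_b=True):
    if trans_a:
        a = a.transpose(0, 2, 1)  # [batch, K, M] -> [batch, M, K]
    if trans_b:
        b = b.transpose(0, 2, 1)  # [batch, N, K] -> [batch, K, N]
    return a @ b                  # [batch, M, N]

a = np.ones((2, 3, 4))            # [batch, M, K]
b = np.ones((2, 5, 4))            # [batch, N, K], NT format
assert batch_matmul_ref(a, b).shape == (2, 3, 5)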


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y, out_shape=None):
def batch_matmul_cblas(
cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Compute batch_matmul using cblas"""
return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
return batch_matmul_blas_common(
cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas
)


@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
@@ -193,9 +211,13 @@ def schedule_batch_matmul_cblas(_, outs):


@autotvm.register_topi_compute("batch_matmul_mkl.x86")
def batch_matmul_mkl(cfg, x, y, out_shape=None):
def batch_matmul_mkl(
cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Compute batch_matmul using mkl"""
return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
return batch_matmul_blas_common(
cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl
)


@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
