[TOPI] Add transpose_a/b & dynamic shape support for batch matmul #8527

Merged · 8 commits · Jul 29, 2021
Changes from 1 commit
14 changes: 12 additions & 2 deletions include/tvm/relay/attrs/nn.h
@@ -1003,16 +1003,26 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for batch matmul operator */
/*! \brief Attributes for batch matmul operator. */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");

TVM_ATTR_FIELD(transpose_a)
.set_default(false)
.describe("Whether the first input tensor is in transposed format.");

TVM_ATTR_FIELD(transpose_b)
.set_default(false)
.describe("Whether the second input tensor is in transposed format.");
}
};
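As an illustration (not part of the diff; shapes below are hypothetical), the new fields surface on the Relay side as call attributes. Note that both C++ defaults are false, while the Python wrapper later in this PR passes transpose_b=True to preserve the legacy NT behavior:

```python
import tvm
from tvm import relay

# Hypothetical shapes, chosen only for illustration.
a = relay.var("a", shape=(2, 16, 32), dtype="float32")
b = relay.var("b", shape=(2, 48, 32), dtype="float32")

# The Python API defaults to the legacy NT format (transpose_a=False, transpose_b=True).
call = relay.nn.batch_matmul(a, b)
print(call.attrs.transpose_a, call.attrs.transpose_b)  # expected: 0 1 (i.e. False, True)
```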

30 changes: 20 additions & 10 deletions python/tvm/relay/op/nn/_nn.py
@@ -1276,24 +1276,34 @@ def dense_pack_shape_func(attrs, inputs, _):


@script
def _batch_matmul_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
if i == 0:
out[i] = max(data_shape[i], weight_shape[i])
else:
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
if transpose_a:
out[1] = tensor_a_shape[2]
else:
out[1] = tensor_a_shape[1]
if transpose_b:
out[2] = tensor_b_shape[1]
else:
out[2] = tensor_b_shape[2]

return out


@reg.register_shape_func("nn.batch_matmul", False)
def batch_matmul_shape_func(attrs, inputs, _):
"""
Shape function for dense op.
Shape function for batch matmul op.
"""
ret = [_batch_matmul_shape_func(inputs[0], inputs[1])]
ret = [
_batch_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", attrs.transpose_a),
expr.IntImm("bool", attrs.transpose_b),
)
]
return ret
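For clarity, here is a plain-Python mirror of the hybrid-script shape logic above (an illustrative helper, not part of the change):

```python
def batch_matmul_out_shape(a_shape, b_shape, transpose_a=False, transpose_b=True):
    """Mirror of _batch_matmul_shape_func in plain Python (illustrative only)."""
    batch = max(a_shape[0], b_shape[0])            # broadcast the batch dimension
    m = a_shape[2] if transpose_a else a_shape[1]  # rows of the output, taken from A
    n = b_shape[1] if transpose_b else b_shape[2]  # cols of the output, taken from B
    return (batch, m, n)

# Legacy NT format: A is [batch, M, K], B is [batch, N, K].
assert batch_matmul_out_shape((1, 8, 4), (1, 6, 4)) == (1, 8, 6)
# NN format: B is [batch, K, N].
assert batch_matmul_out_shape((1, 8, 4), (1, 4, 6), transpose_b=False) == (1, 8, 6)
```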


23 changes: 16 additions & 7 deletions python/tvm/relay/op/nn/nn.py
@@ -2137,32 +2137,41 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True,
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)


def batch_matmul(x, y, out_dtype=""):
def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
Computes batch matrix multiplication of `A` and `B` when `A` and `B` are data
in batch.

The A & B can be transposed. For legacy reason, we use NT format(tensor_a non-transposed
and tensor_b transposed) by default.

.. math::

\mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
\mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :])

Parameters
----------
x : tvm.relay.Expr
tensor_a : tvm.relay.Expr
The first input.

y : tvm.relay.Expr
tensor_b : tvm.relay.Expr
The second input.

out_dtype : str, optional
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul

transpose_a : Optional[bool] = False
Whether the data tensor is in transposed format.

transpose_b : Optional[bool] = True
Whether the weight tensor is in transposed format.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y, out_dtype)
return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b)
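As a usage sketch (shapes here are hypothetical, and this assumes the C++ type relation updated elsewhere in this PR honors the flags), the transpose flags change how the input shapes are read; running type inference shows the resulting output shape:

```python
import tvm
from tvm import relay

# Hypothetical TN example: A is [batch, K, M] (transposed), B is [batch, K, N] (non-transposed).
a = relay.var("a", shape=(2, 16, 32), dtype="float32")
b = relay.var("b", shape=(2, 16, 48), dtype="float32")
y = relay.nn.batch_matmul(a, b, transpose_a=True, transpose_b=False)

mod = tvm.IRModule.from_expr(relay.Function([a, b], y))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expected: Tensor[(2, 32, 48), float32]
```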


# pylint: disable=no-else-return,inconsistent-return-statements
5 changes: 3 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
@@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne

def _compute_batch_matmul(attrs, inputs, out_type):
args = [inputs[0], inputs[1], out_type.shape]
args.append(out_type.dtype if need_out_dtype else None)
args.append(attrs.transpose_a)
args.append(attrs.transpose_b)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
if need_out_dtype:
args.append(out_type.dtype)
return [topi_compute(*args)]

return _compute_batch_matmul
124 changes: 84 additions & 40 deletions python/tvm/topi/nn/batch_matmul.py
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple


def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", out_dtype=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
def batch_matmul(
tensor_a,
tensor_b,
oshape=None,
out_dtype=None,
transpose_a=False,
transpose_b=True,
auto_scheduler_rewritten_layout="",
):
"""Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data in batch. Supports broadcasting for batch dimension.

The A & B can be transposed. For legacy reason, we use NT format(tensor_a non-transposed
and tensor_b transposed) by default.

Parameters
----------
x : tvm.te.Tensor
3-D with shape [batch, M, K]
tensor_a : tvm.te.Tensor
3-D with shape [batch, M, K] or [batch, K, M]

y : tvm.te.Tensor
3-D with shape [batch, N, K]
tensor_b : tvm.te.Tensor
3-D with shape [batch, K, N] or [batch, N, K]

oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.

auto_scheduler_rewritten_layout: str = ""
auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.

out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul

transpose_a : Optional[bool] = False
Whether the data tensor is in transposed format.

transpose_b : Optional[bool] = True
Whether the weight tensor is in transposed format.

Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
x_shape = get_const_tuple(x.shape)
assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
if transpose_a:
XB, XK, XI = get_const_tuple(tensor_a.shape)
else:
XB, XI, XK = get_const_tuple(tensor_a.shape)
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
y_shape = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["b", "j", "k"]
YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["b", "k", "j"]
)
auto_scheduler.remove_index_check(y)
auto_scheduler.remove_index_check(tensor_b)
else:
y_shape = get_const_tuple(y.shape)
assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"
assert len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
if transpose_b:
YB, YJ, YK = get_const_tuple(tensor_b.shape)
else:
YB, YK, YJ = get_const_tuple(tensor_b.shape)

XB = x_shape[0]
YB = y_shape[0]
_, M, K = x.shape
k = te.reduce_axis((0, K), name="k")
assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of tensor_a and tensor_b are inconsistent"
k = te.reduce_axis((0, XK), name="k")
if oshape is None:
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
batch = te.max(XB, YB)
N = y.shape[1]
oshape = (batch, M, N)

if out_dtype is None or out_dtype == x.dtype:
output = te.compute(
oshape,
lambda b, i, j: te.sum(
x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k
),
tag="batch_matmul",
attrs={"layout_free_placeholders": [y]},
batch = (
tvm.tir.Any()
if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var)
else te.max(XB, YB)
)
else:
output = te.compute(
oshape,
lambda b, i, j: te.sum(
x[b if XB != 1 else 0, i, k].astype(out_dtype)
* y[b if YB != 1 else 0, j, k].astype(out_dtype),
axis=k,
),
tag="batch_matmul",
attrs={"layout_free_placeholders": [y]},
oshape = (batch, XI, YJ)
if out_dtype is None:
out_dtype = tensor_a.dtype

if (transpose_a, transpose_b) == (True, True):
compute_lambda = lambda b, i, j: te.sum(
tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype)
* tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype),
axis=k,
)
compute_name = "T_batch_matmul_TT"
elif (transpose_a, transpose_b) == (True, False):
compute_lambda = lambda b, i, j: te.sum(
tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype)
* tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype),
axis=k,
)
compute_name = "T_batch_matmul_TN"
elif (transpose_a, transpose_b) == (False, True):
compute_lambda = lambda b, i, j: te.sum(
tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype)
* tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype),
axis=k,
)
compute_name = "T_batch_matmul_NT"
else: # (transpose_a, transpose_b) == (False, False):
compute_lambda = lambda b, i, j: te.sum(
tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype)
* tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype),
axis=k,
)
compute_name = "T_batch_matmul_NN"

output = te.compute(
oshape,
compute_lambda,
name=compute_name,
tag="batch_matmul",
attrs={"layout_free_placeholders": [tensor_b]},
)
if auto_scheduler_rewritten_layout:
output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)

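A minimal TE-level usage sketch of the extended compute (hypothetical shapes, default schedule only) — here in TT format, where both inputs are pre-transposed:

```python
import tvm
from tvm import te, topi

# Hypothetical shapes: A is [batch, K, M] = (4, 64, 32), B is [batch, N, K] = (4, 128, 64).
A = te.placeholder((4, 64, 32), name="A", dtype="float32")
B = te.placeholder((4, 128, 64), name="B", dtype="float32")

C = topi.nn.batch_matmul(A, B, transpose_a=True, transpose_b=True)
print(C.shape)  # expected: [4, 32, 128]

# Lower with the default schedule just to show the compute builds.
s = te.create_schedule(C.op)
func = tvm.build(s, [A, B, C], target="llvm")
```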
17 changes: 13 additions & 4 deletions python/tvm/topi/testing/batch_matmul.py
@@ -19,7 +19,7 @@
import numpy as np


def batch_matmul(x, y, out_dtype=None):
def batch_matmul(x, y, out_dtype=None, trans_x=False, trans_y=True):
"""batch_matmul operator implemented in numpy.

Parameters
@@ -38,13 +38,22 @@ def batch_matmul(x, y, out_dtype=None):
out : numpy.ndarray
3-D with shape [batch, M, N]
"""
XB, M, _ = x.shape
YB, N, _ = y.shape
if trans_x:
XB, _, M = x.shape
else:
XB, M, _ = x.shape
if trans_y:
YB, N, _ = y.shape
else:
YB, _, N = y.shape
batch = max(XB, YB)
dtype = x.dtype if out_dtype is None else out_dtype
out = np.zeros((batch, M, N)).astype(dtype)
for i in range(batch):
xx = x[i if XB != 1 else 0].astype(dtype)
yy = y[i if YB != 1 else 0].astype(dtype)
out[i] = np.dot(
x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype)
xx.T if trans_x else xx,
yy.T if trans_y else yy,
)
return out
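The NumPy reference above can be sanity-checked against np.matmul directly (a sketch, assuming the helper is exported as tvm.topi.testing.batch_matmul; the shapes are arbitrary):

```python
import numpy as np
from tvm.topi import testing

# Hypothetical TN check: x is [batch, K, M], y is [batch, K, N].
x = np.random.rand(2, 8, 5).astype("float32")
y = np.random.rand(2, 8, 7).astype("float32")

ref = testing.batch_matmul(x, y, trans_x=True, trans_y=False)
np.testing.assert_allclose(ref, np.matmul(x.transpose(0, 2, 1), y), rtol=1e-5)
```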
52 changes: 16 additions & 36 deletions python/tvm/topi/x86/batch_matmul.py
@@ -20,12 +20,14 @@
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, mkl
from .. import generic
from .. import generic, nn
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor


@autotvm.register_topi_compute("batch_matmul.x86")
def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None):
def batch_matmul(
cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch. Supports broadcasting in batch dimension.

@@ -45,40 +45,18 @@ def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None):
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
XB, M, XK = get_const_tuple(x.shape)
YB, N, YK = get_const_tuple(y.shape)
assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistent"
B = te.max(XB, YB)
K = XK
if out_shape is not None:
assert out_shape[0] == B, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
if cfg.is_fallback:
_default_batch_matmul_config(cfg, M, N, K)

k = te.reduce_axis((0, K), name="k")
if out_dtype is None or out_dtype == x.dtype:
C = te.compute(
(B, M, N),
lambda b, i, j: te.sum(
x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k
),
tag="batch_matmul",
)
else:
C = te.compute(
(B, M, N),
lambda b, i, j: te.sum(
x[b if XB != 1 else 0, i, k].astype(out_dtype)
* y[b if YB != 1 else 0, j, k].astype(out_dtype),
axis=k,
),
tag="batch_matmul",
)
return C
if cfg.is_fallback and not transpose_a and transpose_b:
B, N, K = get_const_tuple(tensor_a.shape)
_default_batch_matmul_config(cfg, B, N, K)

return nn.batch_matmul(
tensor_a,
tensor_b,
out_shape,
out_dtype,
transpose_a,
transpose_b,
)


@autotvm.register_topi_schedule("batch_matmul.x86")