[TOPI] Add transpose_a/b & dynamic shape support for batch matmul #8527

Merged 8 commits on Jul 29, 2021
14 changes: 12 additions & 2 deletions include/tvm/relay/attrs/nn.h
@@ -1003,16 +1003,26 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for batch matmul operator */
/*! \brief Attributes for batch matmul operator. */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");

TVM_ATTR_FIELD(transpose_a)
.set_default(false)
.describe("Whether the first input tensor is in transposed format.");

TVM_ATTR_FIELD(transpose_b)
.set_default(false)
.describe("Whether the second input tensor is in transposed format.");
}
};

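As a quick illustration (not part of this diff), the two new attributes are recorded on the Relay call node built by `nn.batch_matmul`; the shapes below are assumptions chosen to match transpose_a=True / transpose_b=False.

# Minimal sketch, assuming a TVM build with this patch applied; shapes are illustrative.
from tvm import relay

a = relay.var("a", shape=(2, 8, 4))    # (batch, K, M) because transpose_a=True
b = relay.var("b", shape=(2, 8, 16))   # (batch, K, N) because transpose_b=False
call = relay.nn.batch_matmul(a, b, transpose_a=True, transpose_b=False)
# Both flags are stored on the call's BatchMatmulAttrs.
print(call.attrs.transpose_a, call.attrs.transpose_b)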
23 changes: 17 additions & 6 deletions python/tvm/relay/frontend/tensorflow.py
@@ -52,6 +52,11 @@
# However, please note that `nn.matmul` is still experimental, so it may have some
# performance issues.
"use_dense": True,
# By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`.
# Change this flag to False to directly convert to `nn.batch_matmul`.
# Note that `nn.batch_matmul` with a format other than NT is still experimental and may
# have some performance issues.
"use_nt_batch_matmul": True,
}

# compatible operators that do NOT require any conversion.
@@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.

@@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.

use_dense_op : bool (Optional) = True
Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be
transposed, may insert extra `transpose` to the original graph.
convert_config : Optional[Dict[str, Any]]
Default config:
use_dense : bool = True
True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
The `nn.dense` op requires the data tensor to be non-transposed and the weight tensor
to be transposed, which may insert extra `transpose` ops into the original graph.
use_nt_batch_matmul : bool = True
True to convert `tf.batch_matmul` to `nn.batch_matmul` strictly in NT format
(transpose_a=False, transpose_b=True).

Returns
-------
@@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
Dict of converted parameters stored in tvm.nd.NDArray format
"""
global TF_DEFAULT_CONFIGS
TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op
if convert_config is not None:
TF_DEFAULT_CONFIGS.update(convert_config)

g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
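For context, a hedged usage sketch of the new `convert_config` argument; `graph_def` and `shape_dict` are assumed to exist and are not part of this diff.

# Convert a TF graph while keeping the original adj_x/adj_y flags on batch_matmul
# instead of forcing the legacy NT layout.
from tvm import relay

mod, params = relay.frontend.from_tensorflow(
    graph_def,                # a tf.compat.v1.GraphDef, assumed
    layout="NHWC",
    shape=shape_dict,         # dict of input shapes, assumed
    convert_config={"use_nt_batch_matmul": False},
)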
15 changes: 12 additions & 3 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -1137,6 +1137,8 @@ def _impl(inputs, attr, params, mod):

def _batch_matmul():
def _impl(inputs, attr, params, mod):
from .tensorflow import TF_DEFAULT_CONFIGS

input_x = inputs[0]
input_y = inputs[1]
orig_shape_x = _infer_shape(input_x, mod)
@@ -1173,9 +1175,16 @@ def _impl(inputs, attr, params, mod):
input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
adj_x = attr["adj_x"]
adj_y = attr["adj_y"]
input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
ret = get_relay_op("batch_matmul")(input_x, input_y)

if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Strictly convert all batch_matmul to NT format
input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
ret = get_relay_op("batch_matmul")(input_x, input_y)
else:
ret = get_relay_op("batch_matmul")(
input_x, input_y, transpose_a=adj_x, transpose_b=adj_y
)

# reshape result back to n-dimensional
if ndim > 3:
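To make the two branches above concrete, here is a small numpy sketch (shapes assumed) showing that rewriting to NT format with explicit transposes and passing adj_x/adj_y straight through compute the same result.

import numpy as np

x = np.random.randn(2, 5, 3).astype("float32")  # adj_x=True, so the logical A is (2, 3, 5)
y = np.random.randn(2, 5, 4).astype("float32")  # adj_y=False, so the logical B is (2, 5, 4)

# Path taken when use_nt_batch_matmul=True: transpose the inputs into NT form,
# where batch_matmul computes matmul(A, B^T).
a_nt = np.transpose(x, (0, 2, 1))
b_nt = np.transpose(y, (0, 2, 1))
nt_result = np.matmul(a_nt, np.transpose(b_nt, (0, 2, 1)))

# Path taken when use_nt_batch_matmul=False: batch_matmul(transpose_a=adj_x,
# transpose_b=adj_y) applies the transposes itself.
direct_result = np.matmul(np.transpose(x, (0, 2, 1)), y)

assert np.allclose(nt_result, direct_result)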
56 changes: 52 additions & 4 deletions python/tvm/relay/op/_tensor_grad.py
@@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad):
GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
"""
lhs, rhs = orig.args
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
# ki, jk -> ij
# jk, ij -> ki
# ij, ki -> jk
return [
collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs),
collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
# ki, kj -> ij
# kj, ij -> ki
# ki, ij -> kj
return [
collapse_sum_like(
_nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs
),
collapse_sum_like(
_nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs
),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
# ik, jk -> ij
# ij, jk -> ik
# ij, ik -> jk
# Keep using NT-format batch_matmul here to avoid introducing extra ops
# TODO(jcf94): Merge all to normal batch_matmul when it is finally ready
return [
collapse_sum_like(
_nn.batch_matmul(
grad,
transpose(rhs, [0, 2, 1]),
transpose_a=False,
transpose_b=True,
),
lhs,
),
collapse_sum_like(
_nn.batch_matmul(
transpose(grad, [0, 2, 1]),
transpose(lhs, [0, 2, 1]),
transpose_a=False,
transpose_b=True,
),
rhs,
),
]
# (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
# ik, kj -> ij
# ij, kj -> ik
# ik, ij -> kj
return [
collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
collapse_sum_like(
_nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
),
collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs),
collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs),
]


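As a sanity check on the einsum rules in the comments above, here is a small numpy sketch for the default NT case (transpose_a=False, transpose_b=True); the batch and matrix sizes are assumptions.

# Forward:  ik, jk -> ij
# d_lhs:    ij, jk -> ik
# d_rhs:    ij, ik -> jk
import numpy as np

lhs = np.random.randn(2, 3, 4)    # (batch, i, k)
rhs = np.random.randn(2, 5, 4)    # (batch, j, k)
grad = np.random.randn(2, 3, 5)   # (batch, i, j), upstream gradient

d_lhs = np.einsum("bij,bjk->bik", grad, rhs)
d_rhs = np.einsum("bij,bik->bjk", grad, lhs)
assert d_lhs.shape == lhs.shape and d_rhs.shape == rhs.shape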
24 changes: 14 additions & 10 deletions python/tvm/relay/op/nn/_nn.py
@@ -1276,24 +1276,28 @@ def dense_pack_shape_func(attrs, inputs, _):


@script
def _batch_matmul_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
if i == 0:
out[i] = max(data_shape[i], weight_shape[i])
else:
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1]
out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2]

return out


@reg.register_shape_func("nn.batch_matmul", False)
def batch_matmul_shape_func(attrs, inputs, _):
"""
Shape function for dense op.
Shape function for batch matmul op.
"""
ret = [_batch_matmul_shape_func(inputs[0], inputs[1])]
ret = [
_batch_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", attrs.transpose_a),
expr.IntImm("bool", attrs.transpose_b),
)
]
return ret


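The hybrid-script shape function above boils down to the following plain-Python rule; this is a sketch with assumed shapes, not code from this diff.

def batch_matmul_out_shape(a_shape, b_shape, transpose_a, transpose_b):
    # Batch dimensions broadcast when one of them is 1.
    batch = max(a_shape[0], b_shape[0])
    m = a_shape[2] if transpose_a else a_shape[1]
    n = b_shape[1] if transpose_b else b_shape[2]
    return (batch, m, n)

# NT example: A is (4, 16, 32), B is (1, 64, 32) -> output (4, 16, 64).
assert batch_matmul_out_shape((4, 16, 32), (1, 64, 32), False, True) == (4, 16, 64)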
26 changes: 17 additions & 9 deletions python/tvm/relay/op/nn/nn.py
@@ -2137,32 +2137,40 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True,
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)


def batch_matmul(x, y, out_dtype=""):
def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.
Compute batch matrix multiplication of `tensor_a` and `tensor_b`.

Both `tensor_a` and `tensor_b` can be transposed. For legacy reasons, we use the NT format
(transpose_a=False, transpose_b=True) by default.

.. math::

\mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
\mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :])

Parameters
----------
x : tvm.relay.Expr
tensor_a : tvm.relay.Expr
The first input.

y : tvm.relay.Expr
tensor_b : tvm.relay.Expr
The second input.

out_dtype : str, optional
Specifies the output data type for mixed precision batch matmul
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul.

transpose_a : Optional[bool] = False
Whether the first tensor is in transposed format.

transpose_b : Optional[bool] = True
Whether the second tensor is in transposed format.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y, out_dtype)
return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b)


# pylint: disable=no-else-return,inconsistent-return-statements
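A brief usage sketch of the updated API with the default NT flags, where the second tensor is expected in (batch, N, K) layout; the shapes are illustrative.

import tvm
from tvm import relay

a = relay.var("a", shape=(8, 16, 32), dtype="float32")   # (batch, M, K)
b = relay.var("b", shape=(8, 64, 32), dtype="float32")   # (batch, N, K), NT default
out = relay.nn.batch_matmul(a, b)                        # transpose_a=False, transpose_b=True

mod = tvm.IRModule.from_expr(relay.Function([a, b], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)   # expected: Tensor[(8, 16, 64), float32]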
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -74,6 +74,11 @@ class DenseAttrs(Attrs):
"""Attributes for nn.dense"""


@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs")
class BatchMatmulAttrs(Attrs):
"""Attributes for nn.batch_matmul"""


@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
class SoftmaxAttrs(Attrs):
"""Attributes for nn.softmax"""
15 changes: 13 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
@@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
x, y = inputs
if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
if (
x.dtype == "int8"
and y.dtype == "int8"
and out_type.dtype == "int32"
and not attrs["transpose_a"]
and attrs["transpose_b"]
):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
@@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
name="batch_matmul_cublas.cuda",
plevel=15,
)
if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
if (
target.kind.name == "cuda"
and nvcc.have_tensorcore(target=target)
and not attrs["transpose_a"]
and attrs["transpose_b"]
):
x, y = inputs
_, M, K = get_const_tuple(x.shape)
_, N, K = get_const_tuple(y.shape)
5 changes: 3 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
@@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne

def _compute_batch_matmul(attrs, inputs, out_type):
args = [inputs[0], inputs[1], out_type.shape]
args.append(out_type.dtype if need_out_dtype else None)
args.append(attrs.transpose_a)
args.append(attrs.transpose_b)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
if need_out_dtype:
args.append(out_type.dtype)
return [topi_compute(*args)]

return _compute_batch_matmul