Commit 8dac20a

Add grad support for batch_matmul

1 parent 8b2ee8b commit 8dac20a

File tree: 9 files changed, +239 −50 lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 52 additions & 4 deletions
@@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad):
            GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
     """
     lhs, rhs = orig.args
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+        # ki, jk -> ij
+        # jk, ij -> ki
+        # ij, ki -> jk
+        return [
+            collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs),
+            collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+        # ki, kj -> ij
+        # kj, ij -> ki
+        # ki, ij -> kj
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs
+            ),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+        # ik, jk -> ij
+        # ij, jk -> ik
+        # ij, ik -> jk
+        # Keep using the NT-format batch_matmul here to avoid involving extra ops
+        # TODO(jcf94): Merge all to normal batch_matmul when it is finally ready
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    grad,
+                    transpose(rhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                lhs,
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    transpose(grad, [0, 2, 1]),
+                    transpose(lhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                rhs,
+            ),
+        ]
+    # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+    # ik, kj -> ij
+    # ij, kj -> ik
+    # ik, ij -> kj
     return [
-        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
-        collapse_sum_like(
-            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
-        ),
+        collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs),
+        collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs),
     ]
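
The einsum comments in each branch are the whole derivation: one line each for the forward contraction, the left gradient, and the right gradient. As a quick sanity check (illustrative only, not part of this commit), the default non-transposed rules ij,kj->ik and ik,ij->kj can be verified numerically with NumPy:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))  # lhs, [batch, M, K]
B = rng.standard_normal((2, 4, 5))  # rhs, [batch, K, N]
G = rng.standard_normal((2, 3, 5))  # upstream gradient, [batch, M, N]

# Analytic gradients following the ij,kj->ik and ik,ij->kj rules above.
grad_A = G @ B.transpose(0, 2, 1)
grad_B = A.transpose(0, 2, 1) @ G

# Finite-difference check of one element of grad_A for the loss sum((A @ B) * G).
eps = 1e-6
dA = np.zeros_like(A)
dA[0, 1, 2] = eps
numeric = (np.sum((A + dA) @ B * G) - np.sum((A - dA) @ B * G)) / (2 * eps)
assert np.isclose(numeric, grad_A[0, 1, 2], atol=1e-4)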

python/tvm/relay/op/nn/_nn.py

Lines changed: 2 additions & 2 deletions
@@ -1279,8 +1279,8 @@ def dense_pack_shape_func(attrs, inputs, _):
 def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
     out = output_tensor((tensor_a_shape.shape[0],), "int64")
     out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
-    out[1] = tensor_a_shape[2 if transpose_a else 1]
-    out[2] = tensor_b_shape[1 if transpose_b else 2]
+    out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1]
+    out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2]

     return out
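
The change is behavior-preserving: the conditional moves out of the subscript so each branch indexes with a constant, presumably a form the hybrid-script compiler that lowers these shape functions can handle, unlike the inline `2 if transpose_a else 1` subscript. A plain-Python restatement of the rule, with a hypothetical helper name:

def batch_matmul_out_shape(a_shape, b_shape, transpose_a=False, transpose_b=True):
    """Output shape [batch, M, N] for every transpose combination."""
    batch = max(a_shape[0], b_shape[0])  # batch may broadcast: either side can be 1
    m = a_shape[2] if transpose_a else a_shape[1]
    n = b_shape[1] if transpose_b else b_shape[2]
    return [batch, m, n]

assert batch_matmul_out_shape([2, 3, 4], [2, 5, 4]) == [2, 3, 5]               # NT (default)
assert batch_matmul_out_shape([2, 4, 3], [2, 4, 5], True, False) == [2, 3, 5]  # TN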

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -74,6 +74,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""


+@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs")
+class BatchMatmulAttrs(Attrs):
+    """Attributes for nn.batch_matmul"""
+
+
 @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
 class SoftmaxAttrs(Attrs):
     """Attributes for nn.softmax"""

python/tvm/relay/op/strategy/cuda.py

Lines changed: 13 additions & 2 deletions
@@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     x, y = inputs
-    if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
+    if (
+        x.dtype == "int8"
+        and y.dtype == "int8"
+        and out_type.dtype == "int32"
+        and attrs["transpose_a"] == False
+        and attrs["transpose_b"] == True
+    ):
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
@@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
-    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+    if (
+        target.kind.name == "cuda"
+        and nvcc.have_tensorcore(target=target)
+        and attrs["transpose_a"] == False
+        and attrs["transpose_b"] == True
+    ):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
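
Both guards exist because the int8 and tensor-core kernels only handle the default NT layout; any other transpose combination now falls through to the generic implementations instead of matching a kernel with the wrong layout. If a specialized kernel were required, a caller could normalize the layout up front, along these lines (illustrative sketch, not part of this commit):

from tvm import relay

def to_nt_batch_matmul(a, b, transpose_a, transpose_b):
    # Rewrite any layout into the NT form (transpose_a=False, transpose_b=True)
    # that the specialized CUDA kernels accept.
    if transpose_a:
        a = relay.transpose(a, [0, 2, 1])  # [B, K, M] -> [B, M, K]
    if not transpose_b:
        b = relay.transpose(b, [0, 2, 1])  # [B, K, N] -> [B, N, K]
    return relay.nn.batch_matmul(a, b)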

python/tvm/topi/cuda/batch_matmul.py

Lines changed: 92 additions & 11 deletions
@@ -27,9 +27,49 @@


 @autotvm.register_topi_compute("batch_matmul.cuda")
-def batch_matmul(cfg, x, y, out_shape=None):
-    """Compute conv2d with NCHW layout"""
-    return nn.batch_matmul(x, y)
+def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True):
+    """Compute batch matrix multiplication of `x` and `y`.
+
+    Both `x` and `y` can be transposed. For legacy reasons, we use NT format
+    (transpose_a=False, transpose_b=True) by default.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return nn.batch_matmul(
+        x,
+        y,
+        oshape=out_shape,
+        out_dtype=out_dtype,
+        transpose_a=transpose_a,
+        transpose_b=transpose_b,
+    )


 @autotvm.register_topi_schedule("batch_matmul.cuda")
@@ -140,20 +180,37 @@ def _callback(op):


 @autotvm.register_topi_compute("batch_matmul_cublas.cuda")
-def batch_matmul_cublas(cfg, x, y, out_shape=None, transpose_a=False, transpose_b=True):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
+def batch_matmul_cublas(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
+    """Compute batch matrix multiplication of `x` and `y`.
+
+    Both `x` and `y` can be transposed. For legacy reasons, we use NT format
+    (transpose_a=False, transpose_b=True) by default.

     Parameters
     ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
     x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
+        3-D with shape [batch, M, K] or [batch, K, M].

     y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
+        3-D with shape [batch, K, N] or [batch, N, K].

-    out_shape : None
-        The output shape
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.

     Returns
     -------
@@ -181,7 +238,31 @@ def schedule_batch_matmul_cublas(_, outs):

 @autotvm.register_topi_compute("batch_matmul_int8.cuda")
 def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
-    """Batch Matmul operator for int8 on CUDA"""
+    """Batch Matmul operator for int8 on CUDA.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file.
+
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
     if out_dtype is None:
         out_dtype = x.dtype
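
All of these CUDA wrappers now forward to the generic topi.nn.batch_matmul compute, using the keyword names visible in the call above (oshape, out_dtype, transpose_a, transpose_b). A small usage sketch with te placeholders:

from tvm import te, topi

A = te.placeholder((2, 4, 3), name="A")  # [batch, K, M]: transposed lhs
B = te.placeholder((2, 4, 5), name="B")  # [batch, K, N]: non-transposed rhs
C = topi.nn.batch_matmul(A, B, transpose_a=True, transpose_b=False)
print(C.shape)  # [2, 3, 5], i.e. [batch, M, N]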

python/tvm/topi/cuda/tensorcore_alter_op.py

Lines changed: 9 additions & 1 deletion
@@ -19,7 +19,7 @@

 import logging
 import math
-from tvm import relay
+from tvm import relay, tir

 from .. import nn

@@ -56,6 +56,14 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):

     B, M, K = x_tensor.shape
     B, N, K = y_tensor.shape
+    if (
+        isinstance(B, tir.expr.Any)
+        or isinstance(M, tir.expr.Any)
+        or isinstance(K, tir.expr.Any)
+        or isinstance(N, tir.expr.Any)
+    ):
+        # Alter op layout is not supported for dynamic shapes
+        return
     M = M.value
     K = K.value
     N = N.value
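
The early return is needed because the lines that follow call .value on each dimension, which only exists on constant IntImm dims; a dynamic tir.expr.Any dimension has no value to pad against. A minimal illustration of the check (assuming a Relay-level dynamic dim):

from tvm import relay, tir

x = relay.var("x", shape=(relay.Any(), 16, 32), dtype="float16")
first_dim = x.type_annotation.shape[0]
print(isinstance(first_dim, tir.expr.Any))  # True: this dim is dynamic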

python/tvm/topi/x86/batch_matmul.py

Lines changed: 43 additions & 21 deletions
@@ -62,10 +62,6 @@ def batch_matmul(
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    if cfg.is_fallback and not transpose_a and transpose_b:
-        B, N, K = get_const_tuple(tensor_a.shape)
-        _default_batch_matmul_config(cfg, B, N, K)
-
     return nn.batch_matmul(
         tensor_a,
         tensor_b,
@@ -145,20 +141,32 @@ def _default_batch_matmul_config(cfg, M, N, K):
     cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


-def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.
+def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib):
+    """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and
+    `tensor_b` are data in batch, using one of the BLAS libraries. Supports broadcasting in the
+    batch dimension.

     Parameters
     ----------
     cfg : ConfigSpace
         Autotvm tuning space config file
-    x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
-    y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
-    out_shape : tuple or None
-        Shape of the output
+
+    tensor_a : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M].
+
+    tensor_b : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K].
+
+    out_shape : List[Optional]
+        Explicit intended output shape of the computation. Can be useful in cases
+        with dynamic input shapes.
+
+    trans_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    trans_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
     lib : A contrib module which implements batch_matmul function
         cblas and mkl are supported
@@ -167,23 +175,33 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
-    XB, M, XK = get_const_tuple(x.shape)
-    YB, N, YK = get_const_tuple(y.shape)
+    assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+    if trans_a:
+        XB, XK, M = get_const_tuple(tensor_a.shape)
+    else:
+        XB, M, XK = get_const_tuple(tensor_a.shape)
+    if trans_b:
+        YB, N, YK = get_const_tuple(tensor_b.shape)
+    else:
+        YB, YK, N = get_const_tuple(tensor_b.shape)
     assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistent"
     if out_shape is not None:
         assert out_shape[0] in (XB, YB), "got invalid output shape"
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
-    return lib.batch_matmul(x, y, False, True)
+    return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b)


 @autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y, out_shape=None):
+def batch_matmul_cblas(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using cblas"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas
+    )


 @autotvm.register_topi_schedule("batch_matmul_cblas.x86")
@@ -193,9 +211,13 @@ def schedule_batch_matmul_cblas(_, outs):


 @autotvm.register_topi_compute("batch_matmul_mkl.x86")
-def batch_matmul_mkl(cfg, x, y, out_shape=None):
+def batch_matmul_mkl(
+    cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Compute batch_matmul using mkl"""
-    return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
+    return batch_matmul_blas_common(
+        cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl
+    )


 @autotvm.register_topi_schedule("batch_matmul_mkl.x86")
@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
