Skip to content

Commit fd7a645

Browse files
jwfrommTushar Dey
authored andcommitted
[Topi] Allow batch_matmul to broadcast along batch dimension. (apache#6616)
* Allow batch_matmul to broadcast along batch dimension. * Added typerel checking. * Fix style issue and respond to feedback. * Fix style. * More formatting issues :( * Fix issues after merge. * Comment update. * Small tweak.
1 parent 371848d commit fd7a645

File tree

9 files changed

+43
-112
lines changed

9 files changed

+43
-112
lines changed

include/tvm/topi/nn/batch_matmul.h

Lines changed: 0 additions & 67 deletions
This file was deleted.

python/tvm/relay/frontend/onnx.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -539,15 +539,6 @@ def flatten_to_3d(x, x_shape):
539539
# Convert a and b into 3 dimensional tensors.
540540
a = flatten_to_3d(inputs[0], a_shape)
541541
b = flatten_to_3d(inputs[1], b_shape)
542-
# Broadcast b to match batch size of a
543-
new_b_shape = _op.concatenate(
544-
[
545-
_op.strided_slice(_op.shape_of(a), [0], [1]),
546-
_op.strided_slice(_op.shape_of(b), [1], [3]),
547-
],
548-
0,
549-
)
550-
b = _op.broadcast_to(b, new_b_shape)
551542
# Transpose matrix dimensions of b.
552543
b = _op.transpose(b, [0, 2, 1])
553544
# Perform a batch matmul.

python/tvm/topi/nn/batch_matmul.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
def batch_matmul(x, y, oshape=None):
2424
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
25-
data in batch.
25+
data in batch. Supports broadcasting for batch dimension.
2626
2727
Parameters
2828
----------
@@ -32,24 +32,30 @@ def batch_matmul(x, y, oshape=None):
3232
y : tvm.te.Tensor
3333
3-D with shape [batch, N, K]
3434
35+
oshape : List[Optional]
36+
Explicit intended output shape of the computation. Can be useful in cases
37+
with dynamic input shapes.
38+
3539
Returns
3640
-------
3741
output : tvm.te.Tensor
3842
3-D with shape [batch, M, N]
3943
"""
44+
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
45+
x_shape = get_const_tuple(x.shape)
46+
y_shape = get_const_tuple(y.shape)
47+
XB = x_shape[0]
48+
YB = y_shape[0]
49+
_, M, K = x.shape
50+
k = te.reduce_axis((0, K), name="k")
4051
if oshape is None:
41-
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
42-
x_shape = get_const_tuple(x.shape)
43-
y_shape = get_const_tuple(y.shape)
44-
assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
52+
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
4553
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
46-
batch, M, K = x.shape
54+
batch = max(XB, YB)
4755
N = y.shape[1]
48-
k = te.reduce_axis((0, K), name="k")
4956
oshape = (batch, M, N)
50-
else:
51-
_, _, K = x.shape
52-
k = te.reduce_axis((0, K), name="k")
5357
return te.compute(
54-
oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
58+
oshape,
59+
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
60+
tag="batch_matmul",
5561
)

python/tvm/topi/testing/batch_matmul.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ def batch_matmul(x, y):
3535
out : numpy.ndarray
3636
3-D with shape [batch, M, N]
3737
"""
38-
batch, M, _ = x.shape
39-
N = y.shape[1]
38+
XB, M, _ = x.shape
39+
YB, N, _ = y.shape
40+
batch = max(XB, YB)
4041
out = np.zeros((batch, M, N)).astype(x.dtype)
4142
for i in range(batch):
42-
out[i] = np.dot(x[i], y[i].T)
43+
out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
4344
return out

python/tvm/topi/x86/batch_matmul.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
@autotvm.register_topi_compute("batch_matmul.x86")
2828
def batch_matmul(cfg, x, y, out_shape=None):
2929
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
30-
data in batch.
30+
data in batch. Supports broadcasting in batch dimension.
3131
3232
Parameters
3333
----------
@@ -45,9 +45,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
4545
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
4646
XB, M, XK = get_const_tuple(x.shape)
4747
YB, N, YK = get_const_tuple(y.shape)
48-
assert XB == YB, "batch dimension doesn't match"
48+
assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
4949
assert XK == YK, "shapes of x and y is inconsistant"
50-
B = XB
50+
B = max(XB, YB)
5151
K = XK
5252
if out_shape is not None:
5353
assert out_shape[0] == B, "got invalid output shape"
@@ -58,7 +58,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
5858

5959
k = te.reduce_axis((0, K), name="k")
6060
C = te.compute(
61-
(B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
61+
(B, M, N),
62+
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
63+
tag="batch_matmul",
6264
)
6365
return C
6466

src/relay/op/nn/nn.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <tvm/topi/nn/flatten.h>
3434
#include <tvm/topi/nn/softmax.h>
3535

36+
#include <algorithm>
3637
#include <string>
3738
#include <vector>
3839

@@ -862,8 +863,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
862863
}
863864
}
864865
if (!is_dyn) {
865-
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
866-
<< "BatchDot: batch dimension doesn't match, "
866+
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
867+
reporter->AssertEQ(y->shape[0], 1))
868+
<< "BatchDot: batch dimensions don't match, "
867869
<< " x shape=" << x->shape << ", y shape=" << y->shape;
868870
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
869871
<< "BatchDot: shapes of x and y is inconsistent, "

src/topi/nn.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include <tvm/runtime/packed_func.h>
2525
#include <tvm/runtime/registry.h>
2626
#include <tvm/topi/nn.h>
27-
#include <tvm/topi/nn/batch_matmul.h>
2827
#include <tvm/topi/nn/bias_add.h>
2928
#include <tvm/topi/nn/bnn.h>
3029
#include <tvm/topi/nn/dense.h>
@@ -68,11 +67,6 @@ TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* r
6867
*rv = nn::bias_add(args[0], args[1], args[2]);
6968
});
7069

71-
/* Ops from nn/batch_matmul.h */
72-
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
73-
*rv = nn::batch_matmul(args[0], args[1]);
74-
});
75-
7670
/* Ops from nn/dilate.h */
7771
TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
7872
*rv = nn::dilate(args[0], args[1], args[2]);

tests/python/frontend/onnx/test_forward.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3628,7 +3628,6 @@ def verify_roi_align(
36283628
test_clip_min_max_as_inputs()
36293629
test_onehot()
36303630
test_matmul()
3631-
test_batch_matmul()
36323631
test_gather()
36333632
test_gatherelements()
36343633
test_gather_nd()

tests/python/topi/python/test_topi_batch_matmul.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@
3232
}
3333

3434

35-
def verify_batch_matmul(batch, M, N, K):
36-
x = te.placeholder((batch, M, K), name="x")
37-
y = te.placeholder((batch, N, K), name="y")
35+
def verify_batch_matmul(x_batch, y_batch, M, N, K):
36+
x = te.placeholder((x_batch, M, K), name="x")
37+
y = te.placeholder((y_batch, N, K), name="y")
3838
dtype = x.dtype
3939

4040
# use memoize to pickle the test data for next time use
4141
@memoize("topi.tests.test_topi_batch_matmul")
4242
def get_ref_data():
43-
a_np = np.random.uniform(size=(batch, M, K)).astype(dtype)
44-
b_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
43+
a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
44+
b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
4545
c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
4646
return (a_np, b_np, c_np)
4747

@@ -67,10 +67,13 @@ def check_device(device, ctx):
6767

6868
@tvm.testing.uses_gpu
6969
def test_batch_matmul():
70-
verify_batch_matmul(1, 16, 16, 32)
71-
verify_batch_matmul(5, 16, 16, 32)
72-
verify_batch_matmul(5, 16, 20, 32)
73-
verify_batch_matmul(30, 16, 20, 32)
70+
verify_batch_matmul(1, 1, 16, 16, 32)
71+
verify_batch_matmul(5, 5, 16, 16, 32)
72+
verify_batch_matmul(5, 5, 16, 20, 32)
73+
verify_batch_matmul(30, 30, 16, 20, 32)
74+
# Test batch broadcasting.
75+
verify_batch_matmul(1, 5, 16, 16, 32)
76+
verify_batch_matmul(5, 1, 16, 16, 32)
7477

7578

7679
if __name__ == "__main__":

0 commit comments

Comments
 (0)