Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi] Allow batch_matmul to broadcast along batch dimension. #6616

Merged
merged 9 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'master' into broadcast_matmul
  • Loading branch information
jwfromm authored Oct 4, 2020
commit 9091a74fa1010eb4a49759d361fbe57fd45cb0a0
20 changes: 14 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,16 +537,24 @@ def flatten_to_3d(x, x_shape):
return out

# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
a = flatten_to_3d(inputs[0], a_shape)
b = flatten_to_3d(inputs[1], b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
# Determine output batch dim.
batch = a_shape[0] if (len(a_shape) != len(b_shape)) else max(a_shape[0], b_shape[0])
# Reshape output to original dimensions.
return _op.reshape(output, [batch, *a_shape[1:-2], a_shape[-2], b_shape[-1]])
# Compute output shape.
final_shape = _op.concatenate(
[
_op.maximum(_op.strided_slice(a_shape, [0], [1]), _op.strided_slice(b_shape, [0], [1])),
_op.strided_slice(a_shape, [1], [infer_shape(a_shape)[0] - 1]),
_op.strided_slice(
b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
),
],
0,
)
return _op.reshape(output, final_shape)
# Otherwise a simple dense op will get the job done.
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
Expand Down
18 changes: 11 additions & 7 deletions python/tvm/topi/nn/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def batch_matmul(x, y, oshape=None):

y : tvm.te.Tensor
3-D with shape [batch, N, K]

oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.

Returns
-------
Expand All @@ -42,14 +46,14 @@ def batch_matmul(x, y, oshape=None):
y_shape = get_const_tuple(y.shape)
XB = x_shape[0]
YB = y_shape[0]
assert (XB == YB) or (XB == 1) or (YB == 1), "batch dimensions don't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
_, M, K = x.shape
batch = max(XB, YB)
N = y.shape[1]
k = te.reduce_axis((0, K), name="k")
if oshape is None:
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
batch = max(XB, YB)
N = y.shape[1]
oshape = (batch, M, N)
return te.compute(
(batch, M, N),
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
tag="batch_matmul",
oshape, lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), tag="batch_matmul"
)
32 changes: 21 additions & 11 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,17 +852,27 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
CHECK(x->shape.size() == 3 && y->shape.size() == 3);
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
reporter->AssertEQ(y->shape[0], 1))
<< "BatchDot: batch dimensions don't match, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;

Array<tvm::PrimExpr> oshape = x->shape;
oshape.Set(0, max(x->shape[0], y->shape[0]));
oshape.Set(2, y->shape[1]);
bool is_dyn = false;
Array<tvm::PrimExpr> oshape;
for (size_t i = 0; i < 3; ++i) {
if (x->shape[i].as<tir::AnyNode>() != nullptr || y->shape[i].as<tir::AnyNode>() != nullptr) {
is_dyn = true;
oshape.push_back(Any());
} else {
oshape.push_back(x->shape[i]);
}
}
if (!is_dyn) {
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
reporter->AssertEQ(y->shape[0], 1))
<< "BatchDot: batch dimensions don't match, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;

oshape.Set(2, y->shape[1]);
}

// assign output type
reporter->Assign(types[2], TensorType(oshape, x->dtype));
Expand Down
42 changes: 36 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ def test_batch_matmul(target, ctx):
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
verify_batch_matmul((1, 4, 3), (2, 3, 4), target, ctx)


def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
Expand All @@ -1002,13 +1003,42 @@ def verify_model(ex, a_shape, b_shape):
mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
relu_node = helper.make_node("Relu", ["out"], ["relu"])

@tvm.testing.uses_gpu
def test_batch_matmul():
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
verify_batch_matmul((2, 4, 3), (3, 4))
verify_batch_matmul((2, 3, 4, 3), (3, 4))
verify_batch_matmul((1, 4, 3), (2, 3, 4))
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")
# matmul
out_np = np.matmul(a_array, b_array)

graph = helper.make_graph(
[mul_node, relu_node],
"matmul_test",
inputs=[
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
],
outputs=[helper.make_tensor_value_info("relu", TensorProto.FLOAT, list(out_np.shape))],
)

model = helper.make_model(graph, producer_name="matmul_test")

a_anys = [relay.Any()] * len(a_shape)
b_anys = [relay.Any()] * len(b_shape)

mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys})

ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
verify_model(ex, a_shape, b_shape)
verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape])
verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape])


# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
@tvm.testing.parametrize_targets("llvm")
def test_batch_matmul_dynamic_model(target, ctx):
verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx)
verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx)
verify_simple_dynamic_model(1, 4, 3), (2, 3, 4), target, ctx)


def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.