Skip to content

Commit

Permalink
[ONNX] only broadcast matmul if the shape has changed (apache#10321)
Browse files Browse the repository at this point in the history
* [ONNX] only broadcast matmul if the shape has changed

* fix copy-pasta mistake
  • Loading branch information
Matthew Brookhart authored Feb 22, 2022
1 parent 33082e0 commit 5956125
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,23 +264,31 @@ def flatten_to_nd(x, x_shape, nd=3):
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
a = inputs[0]
b = inputs[1]
# broadcast a and b
a_broadcasted_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
],
0,
a_broadcasted_shape = fold_constant(
_op.concatenate(
[
out_batch,
_op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
],
0,
)
)
b_broadcasted_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
],
0,
b_broadcasted_shape = fold_constant(
_op.concatenate(
[
out_batch,
_op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
],
0,
)
)
a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
if not tvm.ir.structural_equal(a_shape, a_broadcasted_shape):
a = _op.transform.broadcast_to(a, a_broadcasted_shape)
if not tvm.ir.structural_equal(b_shape, b_broadcasted_shape):
b = _op.transform.broadcast_to(b, b_broadcasted_shape)
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(a, shape_of(a), 3)
b = flatten_to_nd(b, shape_of(b), 3)
Expand Down

0 comments on commit 5956125

Please sign in to comment.