From 5956125ae49cb79b3c0ae6b946e2ae12565ab8c6 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Tue, 22 Feb 2022 15:36:31 -0700
Subject: [PATCH] [ONNX] only broadcast matmul if the shape has changed
 (#10321)

* [ONNX] only broadcast matmul if the shape has changed

* fix copy-pasta mistake
---
 python/tvm/relay/frontend/onnx.py | 36 +++++++++++++++++++------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index db3503dd9c9e..d3ec9f7ed443 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -264,23 +264,31 @@ def flatten_to_nd(x, x_shape, nd=3):
             b = _op.transpose(inputs[1])
             output = _op.nn.dense(a, b, out_dtype=out_dtype)
         else:
+            a = inputs[0]
+            b = inputs[1]
             # broadcast a and b
-            a_broadcasted_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
-                ],
-                0,
+            a_broadcasted_shape = fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
+                    ],
+                    0,
+                )
             )
-            b_broadcasted_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
-                ],
-                0,
+            b_broadcasted_shape = fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
+                    ],
+                    0,
+                )
             )
-            a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
-            b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
+            if not tvm.ir.structural_equal(a_shape, a_broadcasted_shape):
+                a = _op.transform.broadcast_to(a, a_broadcasted_shape)
+            if not tvm.ir.structural_equal(b_shape, b_broadcasted_shape):
+                b = _op.transform.broadcast_to(b, b_broadcasted_shape)
         # Convert a and b into 3 dimensional tensors.
         a = flatten_to_nd(a, shape_of(a), 3)
         b = flatten_to_nd(b, shape_of(b), 3)
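
A minimal standalone sketch of the idea behind the guard this patch adds: once both shape expressions are constant-folded, tvm.ir.structural_equal can tell that the broadcast target already matches the input shape, so the converter skips emitting a redundant broadcast_to. The concrete shapes below are invented purely for illustration; only tvm.ir.structural_equal, relay.const, relay.var, and relay.op.transform.broadcast_to come from TVM itself.

# Illustration only, not part of the patch; the (4, 3, 2) shapes are made up.
import tvm
from tvm import relay
from tvm.relay import op as _op

a = relay.var("a", shape=(4, 3, 2), dtype="float32")

# Both shape expressions are plain constants here, standing in for the
# already-folded a_shape and a_broadcasted_shape in the converter.
a_shape = relay.const([4, 3, 2], dtype="int64")
a_broadcasted_shape = relay.const([4, 3, 2], dtype="int64")

# Structurally identical shapes: the no-op broadcast_to is never emitted.
if not tvm.ir.structural_equal(a_shape, a_broadcasted_shape):
    a = _op.transform.broadcast_to(a, a_broadcasted_shape)

print(a)  # still the bare relay variable, no broadcast_to node in the graph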