Skip to content

Commit 74b6922

Browse files
author
Trevor Morris
authored
[Relay][MXNet] Support broadcast_like (#6561)
1 parent d3ef137 commit 74b6922

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2251,6 +2251,16 @@ def _mx_broadcast_to(inputs, attrs):
22512251
return _op.broadcast_to(data, tgt_shape)
22522252

22532253

2254+
def _mx_broadcast_like(inputs, attrs):
2255+
assert len(inputs) == 2
2256+
for axes in ["lhs_axes", "rhs_axes"]:
2257+
if axes in attrs.attrs:
2258+
raise tvm.error.OpAttributeUnImplemented(
2259+
'Attribute "{}" is not supported for operator broadcast_like.'.format(axes)
2260+
)
2261+
return _op.broadcast_to_like(*inputs)
2262+
2263+
22542264
def _mx_logical_not(inputs, input_types):
22552265
data = inputs[0]
22562266
dtype = _infer_type(data).checked_type.dtype
@@ -2410,6 +2420,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
24102420
"broadcast_logical_and": _mx_broadcast_logical(_op.logical_and),
24112421
"broadcast_logical_xor": _mx_broadcast_logical(_op.logical_xor),
24122422
"broadcast_to": _mx_broadcast_to,
2423+
"broadcast_like": _mx_broadcast_like,
24132424
"logical_not": _mx_logical_not,
24142425
"_equal": _mx_compare(_op.equal, _rename),
24152426
"_not_equal": _mx_compare(_op.not_equal, _rename),

tests/python/frontend/mxnet/test_forward.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,24 @@ def verify(input_shape, shape):
754754
verify((4, 1, 32, 32), (4, 8, 32, 32))
755755

756756

757+
@tvm.testing.uses_gpu
758+
def test_forward_broadcast_like():
759+
def verify(input_shape, like_shape):
760+
x_np = np.random.uniform(size=input_shape).astype("float32")
761+
y_np = np.random.uniform(size=like_shape).astype("float32")
762+
ref_res = mx.nd.broadcast_like(mx.nd.array(x_np), mx.nd.array(y_np))
763+
mx_sym = mx.sym.broadcast_like(mx.sym.var("x"), mx.sym.var("y"))
764+
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape, "y": like_shape})
765+
for target, ctx in tvm.testing.enabled_targets():
766+
for kind in ["graph", "debug"]:
767+
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
768+
op_res = intrp.evaluate()(x_np, y_np)
769+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
770+
771+
verify((1, 2, 3), (3, 2, 3))
772+
verify((4, 1, 32, 32), (4, 8, 32, 32))
773+
774+
757775
@tvm.testing.uses_gpu
758776
def test_forward_logical_not():
759777
a_shape = (3, 4, 5)

0 commit comments

Comments
 (0)