Skip to content

Commit 6faacc6

Browse files
authored
[MXNET]DepthToSpace & SpaceToDepth Operator (#5408)
1 parent e149db2 commit 6faacc6

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,20 @@ def _mx_one_hot(inputs, attrs):
10731073
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
10741074

10751075

1076+
def _mx_depth_to_space(inputs, attrs):
1077+
assert len(inputs) == 1
1078+
new_attrs = {}
1079+
new_attrs["block_size"] = attrs.get_int("block_size")
1080+
return _op.nn.depth_to_space(*inputs, **new_attrs)
1081+
1082+
1083+
def _mx_space_to_depth(inputs, attrs):
1084+
assert len(inputs) == 1
1085+
new_attrs = {}
1086+
new_attrs["block_size"] = attrs.get_int("block_size")
1087+
return _op.nn.space_to_depth(*inputs, **new_attrs)
1088+
1089+
10761090
def _mx_contrib_fifo_buffer(inputs, attrs):
10771091
new_attrs = {}
10781092
new_attrs['axis'] = attrs.get_int('axis')
@@ -1854,6 +1868,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
18541868
"make_loss" : _mx_make_loss,
18551869
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
18561870
"one_hot" : _mx_one_hot,
1871+
"depth_to_space" : _mx_depth_to_space,
1872+
"space_to_depth" : _mx_space_to_depth,
18571873
# vision
18581874
"_contrib_BilinearResize2D" : _mx_resize,
18591875
"_contrib_MultiBoxPrior" : _mx_multibox_prior,

tests/python/frontend/mxnet/test_forward.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,38 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
995995
# _verify_swap_axis((4, 5), (5, 4), 0, 0)
996996

997997

998+
def test_forward_depth_to_space():
999+
def verify(shape, blocksize=2):
1000+
x = np.random.uniform(size=shape).astype("float32")
1001+
ref_res = mx.nd.depth_to_space(mx.nd.array(x), blocksize)
1002+
mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize)
1003+
shape_dict = {"x": x.shape, }
1004+
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
1005+
for target, ctx in ctx_list():
1006+
for kind in ["graph", "debug"]:
1007+
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
1008+
op_res = intrp.evaluate()(x)
1009+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
1010+
1011+
verify((1, 18, 3, 3), 3)
1012+
1013+
1014+
def test_forward_space_to_depth():
1015+
def verify(shape, blocksize=2):
1016+
x = np.random.uniform(size=shape).astype("float32")
1017+
ref_res = mx.nd.space_to_depth(mx.nd.array(x), blocksize)
1018+
mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize)
1019+
shape_dict = {"x": x.shape, }
1020+
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
1021+
for target, ctx in ctx_list():
1022+
for kind in ["graph", "debug"]:
1023+
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
1024+
op_res = intrp.evaluate()(x)
1025+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
1026+
1027+
verify((1, 1, 9, 9), 3)
1028+
1029+
9981030
if __name__ == '__main__':
9991031
test_forward_mlp()
10001032
test_forward_vgg()
@@ -1047,6 +1079,8 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
10471079
test_forward_instance_norm()
10481080
test_forward_layer_norm()
10491081
test_forward_one_hot()
1082+
test_forward_depth_to_space()
1083+
test_forward_space_to_depth()
10501084
test_forward_convolution()
10511085
test_forward_deconvolution()
10521086
test_forward_cond()

0 commit comments

Comments
 (0)