Skip to content

Commit 2e6e7dc

Browse files
author
Trevor Morris
authored
[Frontend][MXNet] Add support for MXNet GroupNorm (#7409)
* Add support for MXNet GroupNorm * Fix python lint * Fix lint
1 parent 0aa90b0 commit 2e6e7dc

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,19 @@ def _mx_layer_norm(inputs, attrs):
495495
return _op.nn.layer_norm(*inputs, **new_attrs)
496496

497497

498+
def _mx_group_norm(inputs, attrs):
    """Convert an MXNet GroupNorm operator into a Relay group_norm call.

    Expects ``inputs`` to be (data, gamma, beta). Raises
    OpAttributeUnimplemented if the caller requested the mean/var outputs,
    which Relay's group_norm does not produce.
    """
    # MXNet GroupNorm always supplies exactly data, gamma, and beta.
    assert len(inputs) == 3
    if attrs.get_bool("output_mean_var", False):
        raise tvm.error.OpAttributeUnimplemented(
            'Attribute "output_mean_var" is not supported for operator Group Norm.'
        )
    # MXNet GroupNorm normalizes over the channel axis (NCHW layout → axis 1).
    relay_attrs = {
        "axis": 1,
        "num_groups": attrs.get_int("num_groups", 1),
        "epsilon": attrs.get_float("eps", 1e-5),
    }
    return _op.nn.group_norm(*inputs, **relay_attrs)
509+
510+
498511
def _mx_slice(inputs, attrs):
499512
new_attrs = {}
500513
begin = list(attrs.get_int_tuple("begin", None))
@@ -2599,6 +2612,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
25992612
"_contrib_SyncBatchNorm": _mx_batch_norm,
26002613
"InstanceNorm": _mx_instance_norm,
26012614
"LayerNorm": _mx_layer_norm,
2615+
"GroupNorm": _mx_group_norm,
26022616
"LRN": _mx_lrn,
26032617
"L2Normalization": _mx_l2_normalize,
26042618
"slice": _mx_slice,

tests/python/frontend/mxnet/test_forward.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,38 @@ def verify(shape, axis=-1):
12631263
verify((2, 5, 6))
12641264

12651265

1266+
@tvm.testing.uses_gpu
def test_forward_group_norm():
    """Check the MXNet GroupNorm frontend conversion against MXNet's own output."""

    def verify(shape, num_groups=1):
        # Random data plus per-channel scale/shift; channel count is shape[1].
        data_np = np.random.uniform(size=shape).astype("float32")
        gamma_np = np.random.uniform(size=(shape[1])).astype("float32")
        beta_np = np.random.uniform(size=(shape[1])).astype("float32")

        # Reference result straight from MXNet's imperative GroupNorm.
        ref_res = mx.nd.GroupNorm(
            data=mx.nd.array(data_np),
            gamma=mx.nd.array(gamma_np),
            beta=mx.nd.array(beta_np),
            num_groups=num_groups,
        )

        # Build the symbolic graph and import it into Relay.
        mx_sym = mx.sym.GroupNorm(
            mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups
        )
        shape_dict = {"x": data_np.shape, "gamma": gamma_np.shape, "beta": beta_np.shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)

        # Run on every enabled target with both the graph and debug executors.
        for target, ctx in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(data_np, gamma_np, beta_np)
                tvm.testing.assert_allclose(
                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
                )

    verify((1, 4, 2), num_groups=4)
    # TODO(trevmorr): MXNet GroupNorm implementation is bugged for cases when num_groups != num_channels
    # https://github.com/apache/incubator-mxnet/pull/18199
    # verify((1, 4, 2, 3), num_groups=2)
    # verify((1, 4, 2, 3))
1296+
1297+
12661298
@tvm.testing.uses_gpu
12671299
def test_forward_one_hot():
12681300
def verify(indices_shape, depth, on_value, off_value, dtype):

0 commit comments

Comments
 (0)