
batchnorm tests (#19836)
Co-authored-by: Wei Chu <weichu@amazon.com>
waytrue17 and Wei Chu authored Feb 5, 2021
1 parent 7d934a7 commit f651452
Showing 2 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -396,6 +396,10 @@ def convert_batchnorm(node, **kwargs):

    momentum = float(attrs.get("momentum", 0.9))
    eps = float(attrs.get("eps", 0.001))
    axis = int(attrs.get("axis", 1))

    if axis != 1:
        raise NotImplementedError("batchnorm axis != 1 is currently not supported.")

    bn_node = onnx.helper.make_node(
        "BatchNormalization",
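
The make_node call in this hunk is truncated here. For context, ONNX's BatchNormalization operator always normalizes over axis 1 (the channel axis of an NCHW input), which is why the converter rejects any other axis. Below is a standalone, hypothetical sketch (not the repository's code) showing how such a node is built with epsilon and momentum attributes like the ones the converter extracts; the input and output names are illustrative only.

# Standalone sketch (assumed, not taken from _op_translations.py): build an
# ONNX BatchNormalization node carrying the epsilon/momentum attributes that
# the converter above reads from the MXNet attrs dict.
import onnx

bn_node = onnx.helper.make_node(
    "BatchNormalization",
    inputs=["data", "gamma", "beta", "moving_mean", "moving_var"],
    outputs=["batchnorm0"],
    name="batchnorm0",
    epsilon=0.001,
    momentum=0.9,
)
print(bn_node)
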
12 changes: 12 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
@@ -927,3 +927,15 @@ def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group,
        **kwargs)
    inputs = [x, w] if no_bias else [x, w, b]
    op_export_test('convolution', M, inputs, tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1])
def test_onnx_export_batchnorm(tmp_path, dtype, momentum):
    x = mx.nd.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
    gamma = mx.nd.random.normal(0, 10, (3)).astype(dtype)
    beta = mx.nd.random.normal(0, 10, (3)).astype(dtype)
    moving_mean = mx.nd.random.normal(0, 10, (3)).astype(dtype)
    moving_var = mx.nd.abs(mx.nd.random.normal(0, 10, (3))).astype(dtype)
    M = def_model('BatchNorm', eps=1e-5, momentum=momentum, fix_gamma=False, use_global_stats=False)
    op_export_test('BatchNorm1', M, [x, gamma, beta, moving_mean, moving_var], tmp_path)
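
For readers without the rest of test_operators.py at hand: def_model and op_export_test are helpers defined earlier in that file; the test above passes BatchNorm's data, gamma, beta, and running statistics through them, presumably to export the operator to ONNX and check its output. A rough, hypothetical sketch of the wrapper pattern a def_model-style helper follows is given below; the real helper may differ.

# Hypothetical sketch of a def_model-style helper (the actual implementation
# lives elsewhere in test_operators.py and may differ): it wraps a named
# operator in a HybridBlock so the harness can hybridize and export it to ONNX.
import mxnet as mx

def def_model(op_name, **params):
    class Model(mx.gluon.HybridBlock):
        def hybrid_forward(self, F, *inputs):
            # Resolve e.g. F.BatchNorm and apply it with the fixed keyword params.
            return getattr(F, op_name)(*inputs, **params)
    return Model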
