2 changes: 1 addition & 1 deletion ci/docker/install/ubuntu_onnx.sh
@@ -31,4 +31,4 @@ apt-get update || true
apt-get install -y libprotobuf-dev protobuf-compiler

echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..."
pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0
pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0 gluonnlp
100 changes: 100 additions & 0 deletions tests/python-pytest/onnx/common.py
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import logging
import os
import random

import mxnet as mx
import numpy as np


def with_seed(seed=None):
"""
A decorator for test functions that manages rng seeds.

    Parameters
    ----------
    seed : int, optional
        The seed to pass to np.random, mx.random and python's random.

    This test decorator sets the np, mx and python random seeds identically
    prior to each test, then outputs those seeds if the test fails or
    if the test requires a fixed seed (as a reminder to make the test
    more robust against random data).

@with_seed()
def test_ok_with_random_data():
...

@with_seed(1234)
def test_not_ok_with_random_data():
...

    Use of the @with_seed() decorator for all tests creates
    test isolation and reproducibility of failures. When a
test fails, the decorator outputs the seed used. The user
can then set the environment variable MXNET_TEST_SEED to
the value reported, then rerun the test with:

pytest --verbose --capture=no <test_module_name.py>::<failing_test>

To run a test repeatedly, set MXNET_TEST_COUNT=<NNN> in the environment.
To see the seeds of even the passing tests, add '--log-level=DEBUG' to pytest.
"""
def test_helper(orig_test):
@functools.wraps(orig_test)
def test_new(*args, **kwargs):
test_count = int(os.getenv('MXNET_TEST_COUNT', '1'))
env_seed_str = os.getenv('MXNET_TEST_SEED')
for i in range(test_count):
if seed is not None:
this_test_seed = seed
log_level = logging.INFO
elif env_seed_str is not None:
this_test_seed = int(env_seed_str)
log_level = logging.INFO
else:
this_test_seed = np.random.randint(0, np.iinfo(np.int32).max)
log_level = logging.DEBUG
post_test_state = np.random.get_state()
np.random.seed(this_test_seed)
mx.random.seed(this_test_seed)
random.seed(this_test_seed)
                # 'pytest --log-level=DEBUG' shows this msg even with an ensuing core dump.
                test_count_msg = '{} of {}: '.format(i+1, test_count) if test_count > 1 else ''
pre_test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}'
' to reproduce.').format(test_count_msg, this_test_seed)
on_err_test_msg = ('{}Error seen with seeded test, use MXNET_TEST_SEED={}'
' to reproduce.').format(test_count_msg, this_test_seed)
logging.log(log_level, pre_test_msg)
try:
orig_test(*args, **kwargs)
except:
# With exceptions, repeat test_msg at WARNING level to be sure it's seen.
if log_level < logging.WARNING:
logging.warning(on_err_test_msg)
raise
finally:
# Provide test-isolation for any test having this decorator
mx.nd.waitall()
np.random.set_state(post_test_state)
return test_new
return test_helper
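
For reference, a minimal sketch of how a test module might consume this decorator; the module name, test function, and data below are illustrative only and not part of this change:

# hypothetical_test_module.py -- example usage of the with_seed decorator above
import mxnet as mx
import numpy as np

from common import with_seed


@with_seed()
def test_random_input_is_finite():
    # The decorator has already seeded np, mx and python random identically,
    # so a failure here can be reproduced by re-running the test with
    # MXNET_TEST_SEED=<reported seed>.
    data = mx.nd.array(np.random.uniform(-1.0, 1.0, (4, 4)))
    assert np.isfinite(data.asnumpy()).all()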

102 changes: 81 additions & 21 deletions tests/python-pytest/onnx/test_onnxruntime.py
@@ -19,6 +19,9 @@
import numpy as np
import onnxruntime

from mxnet.test_utils import assert_almost_equal
from common import with_seed

import json
import os
import pytest
@@ -30,10 +33,11 @@
['apron.jpg', [411,578,638,639,689,775]],
['dolphin.jpg', [2,3,4,146,147,148,395]],
['hammerheadshark.jpg', [3,4]],
['lotus.jpg', [723,738,985]]
['lotus.jpg', [716,723,738,985]]
]

test_models = [
'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25',
'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2',
@@ -42,6 +46,7 @@
'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
]

@with_seed()
@pytest.mark.parametrize('model', test_models)
def test_cv_model_inference_onnxruntime(tmp_path, model):
def get_gluon_cv_model(model_name, tmp):
@@ -97,25 +102,80 @@ def download_test_images(tmpdir):


tmp_path = str(tmp_path)
#labels = load_imgnet_labels(tmp_path)
test_images = download_test_images(tmp_path)
sym_file, params_file = get_gluon_cv_model(model, tmp_path)
onnx_file = export_model_to_onnx(sym_file, params_file)

# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

for img, accepted_ids in test_images:
img_data = normalize_image(os.path.join(tmp_path,img))
raw_result = session.run([], {input_name: img_data})
res = softmax(np.array(raw_result)).tolist()
class_idx = np.argmax(res)
assert(class_idx in accepted_ids)

shutil.rmtree(tmp_path)

try:
#labels = load_imgnet_labels(tmp_path)
test_images = download_test_images(tmp_path)
sym_file, params_file = get_gluon_cv_model(model, tmp_path)
onnx_file = export_model_to_onnx(sym_file, params_file)

# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

for img, accepted_ids in test_images:
img_data = normalize_image(os.path.join(tmp_path,img))
raw_result = session.run([], {input_name: img_data})
res = softmax(np.array(raw_result)).tolist()
class_idx = np.argmax(res)
assert(class_idx in accepted_ids)

finally:
shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
tmp_path = str(tmp_path)
try:
import gluonnlp as nlp
dataset = 'book_corpus_wiki_en_uncased'
ctx = mx.cpu(0)
model, vocab = nlp.model.get_model(
name=model,
ctx=ctx,
dataset_name=dataset,
pretrained=False,
use_pooler=True,
use_decoder=False,
use_classifier=False)
model.initialize(ctx=ctx)
model.hybridize(static_alloc=True)

batch = 5
seq_length = 16
# create synthetic test data
inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

prefix = "%s/bert" % tmp_path
model.export(prefix)
sym_file = "%s-symbol.json" % prefix
params_file = "%s-0000.params" % prefix
onnx_file = "%s.onnx" % prefix


input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, np.float32, onnx_file)


# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
onnx_inputs = [inputs, token_types, valid_length]
input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
pred_onx, cls_onx = session.run(None, input_dict)

assert_almost_equal(seq_encoding, pred_onx)
assert_almost_equal(cls_encoding, cls_onx)

finally:
shutil.rmtree(tmp_path)