Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Update onnx support to work with onnx 1.7.0 with most CV models #19017

Merged
merged 35 commits into from
Sep 11, 2020
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
51715d7
fix pooling_convention warning when convert model to onnx (#18529)
HaoLiuHust Aug 10, 2020
7b7141b
Prevent uninitialized variable error.
Aug 18, 2020
aa1515b
Initial work to get Dropout to work with onnx 1.7
Aug 26, 2020
77fb75f
Remove trailing whitespace for pylint.
Aug 26, 2020
ae1e74d
Fix tensor initialization for Dropout operator input.
Aug 27, 2020
0faeeef
Update Clip operator to support latest ONNX opset versions by moving …
Aug 27, 2020
e9453c5
Fix whitespace.
Aug 27, 2020
1d5b664
Add support for importing Dropout operator in ONNX opset version >= 12.
Aug 28, 2020
9c5c034
Add support for import ONNX opsets >= 11 to clip operator.
Aug 28, 2020
aabcdd5
Add optional opset_version parameter that defaults to latest opset ve…
Aug 28, 2020
edd6f53
Add optional parameter to create_model() that allows user to specify …
Aug 28, 2020
2dfa22f
Use opset_version argument to determine operator format.
Aug 28, 2020
6c4e555
Add a opset_version parameter to from_onnx() so at operator conversio…
Aug 28, 2020
7305b9d
For Clip and Dropout operators, use opset version from passed proto_o…
Aug 28, 2020
39da0fc
Use same tolerances that are in master.
Aug 31, 2020
e36c200
Change Pad operator to use inputs instead of attributes for newer ops…
Sep 1, 2020
e4a9318
Add documentation opset_version parameter.
Sep 1, 2020
85a0ea6
Add opset_version parameters to unit tests.
Sep 1, 2020
0738620
Add test script for testing inference with onnxruntime on CV models f…
Sep 1, 2020
885862d
Add license and clean up imports.
Sep 2, 2020
9bb2b47
Install onnxruntime in docker container for unit tests.
Sep 2, 2020
50d929c
Add onnxruntime to test dependencies.
Sep 2, 2020
a6e6967
Install onnxruntime into CentOS docker image.
Sep 2, 2020
0bfec8e
Disable testing squeezenet models for now.
Sep 2, 2020
26708e3
Update onnx version.
Sep 2, 2020
d620548
Fix typo.
Sep 2, 2020
c7b55c1
Use mx.image.imread instead of PIL module.
Sep 2, 2020
f49e47a
ONNX import: use Conv pad attribute for symmetrical padding (#18675)
Kh4L Jul 24, 2020
36d92ca
Install onnx in CentOS containers when installing python.
Sep 2, 2020
b102f78
Update import and export of some ONNX ops to support newer opset vers…
Sep 3, 2020
8bd6a64
Re-enable squeezenet model testings in onnxruntime.
Sep 3, 2020
a3ea851
Run the onnxruntime inference tests in the ONNX pipeline instead of n…
Sep 3, 2020
a5246fe
Add missed return value.
Sep 3, 2020
29dcdf3
Refactor code based on review comment.
Sep 8, 2020
d597b5a
Since the onnx tests are only run on ubuntu_cpu images, we don't need…
Sep 8, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix pooling_convention warning when convert model to onnx (#18529)
* fix  pooling_convention warning

* fix pooling_convention warning

* fix lint

Co-authored-by: JackieWu <wkcn@live.cn>
  • Loading branch information
2 people authored and Joe Evans committed Aug 26, 2020
commit 51715d7b2604762a3fbc56e583042f106a6a03e2
41 changes: 27 additions & 14 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,13 @@ def convert_pooling(node, **kwargs):
p_value = attrs.get('p_value', 'None')

pooling_convention = attrs.get('pooling_convention', 'valid')

ceil_mode = False
if pooling_convention == 'full':
pooling_warning = "Pooling: ONNX currently doesn't support pooling_convention. " \
"This might lead to shape or accuracy issues. " \
"https://github.com/onnx/onnx/issues/549"

if onnx.__version__ < "1.5.0":
pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \
"This might lead to shape or accuracy issues. " \
"https://github.com/onnx/onnx/issues/549"
ceil_mode = True
logging.warning(pooling_warning)

pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
Expand Down Expand Up @@ -694,15 +695,27 @@ def convert_pooling(node, **kwargs):
name=name
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)
if onnx.__version__ >= "1.5.0":
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name,
ceil_mode=ceil_mode
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)

return [node]

Expand Down