Skip to content

Commit

Permalink
[MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon. (apache#10605)
Browse files Browse the repository at this point in the history

* gluon import

* gluon tests

* shape issues.

* remove the dim_change list

* onnx backend tests

* changes to match onnx op set version 7

* fix

* lint fix

* add new folder

* fix

* fix

* rename test file

* comments

* comment fix

* check for opset differences.

* fix

* bcast test
  • Loading branch information
anirudhacharya authored and anirudh2290 committed Jun 4, 2018
1 parent 21997a2 commit f754498
Show file tree
Hide file tree
Showing 18 changed files with 680 additions and 538 deletions.
4 changes: 2 additions & 2 deletions ci/docker/install/ubuntu_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ echo "Installing libprotobuf-dev and protobuf-compiler ..."
apt-get install -y libprotobuf-dev protobuf-compiler

echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..."
pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.1.1 Pillow==5.0.0 tabulate==0.7.5
pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.1.1 Pillow==5.0.0 tabulate==0.7.5
pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.2.1 Pillow==5.0.0 tabulate==0.7.5
pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.5.2 onnx==1.2.1 Pillow==5.0.0 tabulate==0.7.5
5 changes: 3 additions & 2 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,9 @@ integrationtest_ubuntu_cpu_onnx() {
set -ex
export PYTHONPATH=./python/
python example/onnx/super_resolution.py
pytest tests/python-pytest/onnx/onnx_backend_test.py
pytest tests/python-pytest/onnx/onnx_test.py
pytest tests/python-pytest/onnx/import/mxnet_backend_test.py
pytest tests/python-pytest/onnx/import/onnx_import_test.py
pytest tests/python-pytest/onnx/import/gluon_backend_test.py
}

integrationtest_ubuntu_gpu_python() {
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/contrib/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
"""Module for ONNX model format support for Apache MXNet."""

from ._import.import_model import import_model, get_model_metadata
from ._import.import_to_gluon import import_to_gluon
1 change: 1 addition & 0 deletions python/mxnet/contrib/onnx/_import/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
"""ONNX Import module"""
from . import import_model
from . import import_onnx
from . import import_to_gluon
50 changes: 43 additions & 7 deletions python/mxnet/contrib/onnx/_import/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
""" Support import export formats."""
from __future__ import absolute_import as _abs
from .... import symbol
from .... import cpu, gpu
from .... import ndarray as nd
from ....base import string_types
from .import_helper import _convert_map as convert_map
Expand All @@ -33,6 +34,9 @@ def __init__(self):
self._params = {}
self._num_input = 0
self._num_param = 0
self.aux_dict = {}
self.arg_dict = {}
self.model_metadata = {}

def _convert_operator(self, node_name, op_name, attrs, inputs):
"""Convert from onnx operator to mxnet operator.
Expand Down Expand Up @@ -84,6 +88,8 @@ def from_onnx(self, graph):
params : dict
A dict of name: nd.array pairs, used as pretrained weights
"""
#get input, output shapes
self.model_metadata = self.get_graph_metadata(graph)
# parse network inputs, aka parameters
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
Expand All @@ -99,10 +105,6 @@ def from_onnx(self, graph):
else:
self._nodes[i.name] = symbol.Variable(name=i.name)

# For storing arg and aux params for the graph.
auxDict = {}
argDict = {}

# constructing nodes, nodes are stored as directed acyclic graph
# converting NodeProto message
for node in graph.node:
Expand All @@ -119,18 +121,18 @@ def from_onnx(self, graph):
# splitting params into args and aux params
for args in mxnet_sym.list_arguments():
if args in self._params:
argDict.update({args: nd.array(self._params[args])})
self.arg_dict.update({args: nd.array(self._params[args])})
for aux in mxnet_sym.list_auxiliary_states():
if aux in self._params:
auxDict.update({aux: nd.array(self._params[aux])})
self.aux_dict.update({aux: nd.array(self._params[aux])})

# now return the outputs
out = [self._nodes[i.name] for i in graph.output]
if len(out) > 1:
out = symbol.Group(out)
else:
out = out[0]
return out, argDict, auxDict
return out, self.arg_dict, self.aux_dict

def get_graph_metadata(self, graph):
"""
Expand All @@ -155,6 +157,40 @@ def get_graph_metadata(self, graph):
}
return metadata

def graph_to_gluon(self, graph, context):
    """Construct a Gluon SymbolBlock from an ONNX graph.

    Parameters
    ----------
    graph : onnx protobuf object
        The loaded ONNX graph.
    context : str
        Device context for the parameters. Should be 'CPU' or 'GPU'
        (matched case-insensitively).

    Returns
    -------
    sym_block : gluon.nn.SymbolBlock
        The returned Gluon SymbolBlock with pretrained weights loaded.
    """
    sym, arg_params, aux_params = self.from_onnx(graph)
    metadata = self.get_graph_metadata(graph)
    data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
    data_inputs = [symbol.var(data_name) for data_name in data_names]

    # Case-insensitive comparison: previously a lowercase 'gpu' silently
    # fell through to cpu(), binding parameters to the wrong device.
    ctx = gpu() if str(context).upper() == 'GPU' else cpu()
    # Imported locally, presumably to avoid a circular import between
    # contrib.onnx and gluon -- keep it function-scoped.
    from ....gluon import SymbolBlock
    net = SymbolBlock(outputs=sym, inputs=data_inputs)
    net_params = net.collect_params()
    # Copy pretrained arg/aux arrays into the matching network parameters.
    # NOTE(review): shape is assigned before _load_init, presumably so
    # deferred-shape parameters accept the loaded array -- confirm against
    # gluon.Parameter semantics.
    for param_name in arg_params:
        if param_name in net_params:
            net_params[param_name].shape = arg_params[param_name].shape
            net_params[param_name]._load_init(arg_params[param_name], ctx=ctx)
    for param_name in aux_params:
        if param_name in net_params:
            net_params[param_name].shape = aux_params[param_name].shape
            net_params[param_name]._load_init(aux_params[param_name], ctx=ctx)
    return net

def _parse_array(self, tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
Expand Down
48 changes: 48 additions & 0 deletions python/mxnet/contrib/onnx/_import/import_to_gluon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Import ONNX model to gluon interface"""
# pylint: disable=no-member

from .import_onnx import GraphProto

def import_to_gluon(model_file, context):
    """Import an ONNX model file into a Gluon SymbolBlock object.

    Parameters
    ----------
    model_file : str
        ONNX model file name.
    context : str
        Device context. Should be 'CPU' or 'GPU'.

    Returns
    -------
    sym_block : :class:`~mxnet.gluon.SymbolBlock`
        A SymbolBlock object representing the given model file.

    Raises
    ------
    ImportError
        If the ``onnx`` package is not installed.
    """
    # Verify the optional onnx dependency before doing any other work,
    # so callers get a clear error instead of partially-constructed state.
    try:
        import onnx
    except ImportError:
        raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                          + " install - https://github.com/onnx/onnx#installation")
    graph = GraphProto()
    model_proto = onnx.load(model_file)
    net = graph.graph_to_gluon(model_proto.graph, context)
    return net
Loading

0 comments on commit f754498

Please sign in to comment.