Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-106] [ONNX_MXNet] Change parameter names in imported model (#1…
Browse files Browse the repository at this point in the history
…0472)

* fix param names in model

* corresponding changes to tutorials

* test rendering

* add comments to data name fetch stmt.
  • Loading branch information
anirudhacharya authored and cjolivier01 committed Apr 10, 2018
1 parent 656e352 commit 33823d3
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 46 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/onnx/fine_tuning_gluon.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ sym.get_internals()



```<Symbol group [input_0, param_0, param_1, convolution0, relu0, lrn0, pad0, pooling0, param_2, param_3, convolution1, relu1, lrn1, pad1, pooling1, param_4, param_5, convolution2, relu2, param_6, param_7, convolution3, relu3, param_8, param_9, convolution4, relu4, pad2, pooling2, _mulscalar0, param_10, param_11, _mulscalar1, fullyconnected0, relu5, _mulscalar2, param_12, param_13, _mulscalar3, fullyconnected1, relu6, _mulscalar4, param_14, param_15, _mulscalar5, fullyconnected2, softmax0]>```<!--notebook-skip-line-->
```<Symbol group [gpu_0/data_0, gpu_0/conv1_w_0, gpu_0/conv1_b_0, convolution0, relu0, lrn0, pad0, pooling0, gpu_0/conv2_w_0, gpu_0/conv2_b_0, convolution1, relu1, lrn1, pad1, pooling1, gpu_0/conv3_w_0, gpu_0/conv3_b_0, convolution2, relu2, gpu_0/conv4_w_0, gpu_0/conv4_b_0, convolution3, relu3, gpu_0/conv5_w_0, gpu_0/conv5_b_0, convolution4, relu4, pad2, pooling2, flatten0, gpu_0/fc6_w_0, linalg_gemm20, gpu_0/fc6_b_0, _mulscalar0, broadcast_add0, relu5, flatten1, gpu_0/fc7_w_0, linalg_gemm21, gpu_0/fc7_b_0, _mulscalar1, broadcast_add1, relu6, flatten2, gpu_0/fc8_w_0, linalg_gemm22, gpu_0/fc8_b_0, _mulscalar2, broadcast_add2, softmax0]>```<!--notebook-skip-line-->



Expand Down Expand Up @@ -258,7 +258,7 @@ We create a symbol block that is going to hold all our pre-trained layers, and a


```python
pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('input_0'))
pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('gpu_0/data_0'))
net_params = pre_trained.collect_params()
for param in new_arg_params:
if param in net_params:
Expand Down
15 changes: 13 additions & 2 deletions docs/tutorials/onnx/inference_on_onnx_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,22 @@ We pick a context, GPU if available, otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
```

And load them into a MXNet Gluon symbol block. For ONNX models the default input name is `input_0`.
We obtain the data names of the inputs to the model, by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list:

```python
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params]
print(data_names)
```


```['gpu_0/data_0']```


And load them into a MXNet Gluon symbol block.

```python
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('gpu_0/data_0'))
net_params = net.collect_params()
for param in arg_params:
if param in net_params:
Expand Down
23 changes: 16 additions & 7 deletions docs/tutorials/onnx/super_resolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})



![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)
![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png) <!--notebook-skip-line-->



Expand All @@ -71,10 +71,19 @@ test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]

We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data and assign the loaded weights from the two parameter objects - argument parameters and auxilliary parameters.

To obtain the input data names we run the following line, which picks all the inputs of the symbol graph excluding the argument and auxiliary parameters:

```python
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0',test_image.shape)], label_shapes=None)
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg and graph_input not in aux]
print(data_names)
```

```['1']```

```python
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0],test_image.shape)], label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
```

Expand Down Expand Up @@ -105,10 +114,10 @@ result_img = Image.merge(
result_img.save("super_res_output.jpg")
```

Here's the input image and the resulting output images compared. As you can see, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
You can now compare the input image and the resulting output image. As you will notice, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.

| Input Image | Output Image |
| ----------- | ------------ |
| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |
| Input Image | Output Image | <!--notebook-skip-line-->
| ----------- | ------------ | <!--notebook-skip-line-->
| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) | <!--notebook-skip-line-->

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
8 changes: 6 additions & 2 deletions example/onnx/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@ def get_test_image():

def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
"""Perform inference on image using mxnet"""
# To fetch the data names of the input to the model we list the inputs of the symbol graph
# and exclude the argument and auxiliary parameters from the list
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params]
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
mod.set_params(arg_params=arg_params, aux_params=aux_params)

# run inference
Expand Down
19 changes: 4 additions & 15 deletions python/mxnet/contrib/onnx/_import/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
def __init__(self):
self._nodes = {}
self._params = {}
self._renames = {}
self._num_input = 0
self._num_param = 0

Expand Down Expand Up @@ -72,9 +71,6 @@ def _convert_operator(self, node_name, op_name, attrs, inputs):

def from_onnx(self, graph):
"""Construct symbol from onnx graph.
The inputs from onnx graph is vague, only providing "1", "2"...
For convenience, we rename the `real` input names to "input_0",
"input_1"... And renaming parameters to "param_0", "param_1"...
Parameters
----------
Expand All @@ -98,17 +94,10 @@ def from_onnx(self, graph):
for i in graph.input:
if i.name in self._params:
# i is a param instead of input
name_param = 'param_{}'.format(self._num_param)
self._num_param += 1
self._params[name_param] = self._params.pop(i.name)
self._nodes[name_param] = symbol.Variable(name=name_param,
shape=self._params[name_param].shape)
self._renames[i.name] = name_param
self._nodes[i.name] = symbol.Variable(name=i.name,
shape=self._params[i.name].shape)
else:
name_input = 'input_{}'.format(self._num_input)
self._num_input += 1
self._nodes[name_input] = symbol.Variable(name=name_input)
self._renames[i.name] = name_input
self._nodes[i.name] = symbol.Variable(name=i.name)

# For storing arg and aux params for the graph.
auxDict = {}
Expand All @@ -121,7 +110,7 @@ def from_onnx(self, graph):
node_name = node.name.strip()
node_name = node_name if node_name else None
onnx_attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
inputs = [self._nodes[i] for i in node.input]
mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)

for k, i in zip(list(node.output), range(len(mxnet_sym.list_outputs()))):
Expand Down
13 changes: 11 additions & 2 deletions tests/python-pytest/onnx/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,18 @@ def run(self, inputs, **kwargs):
else:
raise NotImplementedError("Only CPU context is supported for now")

mod = mx.mod.Module(symbol=self.symbol, data_names=['input_0'], context=ctx,
# To fetch the data names of the input to the model we list the inputs of the symbol graph
# and exclude the argument and auxiliary parameters from the list
data_names = [graph_input for graph_input in self.symbol.list_inputs()
if graph_input not in self.arg_params and graph_input not in self.aux_params]

data_shapes = []
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))

mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)],
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)

Expand Down
35 changes: 19 additions & 16 deletions tests/python-pytest/onnx/onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,27 +117,24 @@ def test_super_resolution_example():

inputs = sym.list_inputs()
assert len(inputs) == 9
for i, input_param in enumerate(['param_7', 'param_5', 'param_3', 'param_1',
'input_0', 'param_0', 'param_2', 'param_4', 'param_6']):
for i, input_param in enumerate(['9', '7', '5', '3', '1', '2', '4', '6', '8']):
assert inputs[i] == input_param

assert len(sym.list_outputs()) == 1
assert sym.list_outputs()[0] == 'reshape5_output'

attrs_keys = sym.attr_dict().keys()
assert len(attrs_keys) == 19
for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7',
'param_6', 'param_1', 'param_0', 'param_3',
'param_2', 'reshape2', 'reshape3', 'reshape0',
'reshape1', 'convolution2', 'convolution3',
'convolution0', 'convolution1', 'reshape5',
'transpose0']):
for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0',
'transpose0', '6', 'reshape0', 'reshape2',
'reshape3', '3', 'reshape1', '5', '4', '7',
'convolution1', '9', '2', 'convolution3',
'reshape5', '8']):
assert key_item in attrs_keys

param_keys = arg_params.keys()
assert len(param_keys) == 8
for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6',
'param_1', 'param_0', 'param_3', 'param_2']):
for i, param_item in enumerate(['3', '2', '5', '4', '7', '6', '9', '8']):
assert param_item in param_keys

logging.info("Asserted the result of the onnx model conversion")
Expand Down Expand Up @@ -192,8 +189,10 @@ def test_bvlc_googlenet():
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params]
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
mod.set_params(arg_params=arg_params, aux_params=aux_params,
allow_missing=True, allow_extra=True)
# run inference
Expand All @@ -214,8 +213,10 @@ def test_bvlc_reference_caffenet():
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params]
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
mod.set_params(arg_params=arg_params, aux_params=aux_params,
allow_missing=True, allow_extra=True)
# run inference
Expand All @@ -236,8 +237,10 @@ def test_bvlc_rcnn_ilsvrc13():
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params]
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
mod.set_params(arg_params=arg_params, aux_params=aux_params,
allow_missing=True, allow_extra=True)
# run inference
Expand Down

0 comments on commit 33823d3

Please sign in to comment.