diff --git a/docs/tutorials/onnx/fine_tuning_gluon.md b/docs/tutorials/onnx/fine_tuning_gluon.md
index 7961f9f6b8a8..4116ff631ebd 100644
--- a/docs/tutorials/onnx/fine_tuning_gluon.md
+++ b/docs/tutorials/onnx/fine_tuning_gluon.md
@@ -230,7 +230,7 @@ sym.get_internals()
 
 
 
-``````
+``````
 
 
 
@@ -258,7 +258,7 @@ We create a symbol block that is going to hold all our pre-trained layers, and a
 
 
 ```python
-pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('input_0'))
+pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = pre_trained.collect_params()
 for param in new_arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md
index 9415d0063c83..bdda820119e8 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -104,11 +104,22 @@ We pick a context, GPU if available, otherwise CPU
 ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
 ```
 
 
-And load them into a MXNet Gluon symbol block. For ONNX models the default input name is `input_0`.
+We obtain the data names of the inputs to the model by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list:
+```python
+data_names = [graph_input for graph_input in sym.list_inputs()
+              if graph_input not in arg_params and graph_input not in aux_params]
+print(data_names)
+```
+
+
+```['gpu_0/data_0']```
+
+
+And load them into an MXNet Gluon symbol block.
 
 ```python
-net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))
+net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = net.collect_params()
 for param in arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/super_resolution.md b/docs/tutorials/onnx/super_resolution.md
index dc75b6606f20..36c06b743c8e 100644
--- a/docs/tutorials/onnx/super_resolution.md
+++ b/docs/tutorials/onnx/super_resolution.md
@@ -51,7 +51,7 @@ mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})
 
 
 
-![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)
+![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)
 
 
 
@@ -71,10 +71,19 @@ test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
 
 We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data and assign the loaded weights from the two parameter objects - argument parameters and auxilliary parameters.
+To obtain the input data names we run the following line, which picks all the inputs of the symbol graph excluding the argument and auxiliary parameters:
 
 ```python
-mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-mod.bind(for_training=False, data_shapes=[('input_0',test_image.shape)], label_shapes=None)
+data_names = [graph_input for graph_input in sym.list_inputs()
+              if graph_input not in arg and graph_input not in aux]
+print(data_names)
+```
+
+```['1']```
+
+```python
+mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+mod.bind(for_training=False, data_shapes=[(data_names[0],test_image.shape)], label_shapes=None)
 mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
 ```
 
 
@@ -105,10 +114,10 @@ result_img = Image.merge(
 result_img.save("super_res_output.jpg")
 ```
 
 
-Here's the input image and the resulting output images compared. As you can see, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
+You can now compare the input image and the resulting output image. As you will notice, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
 
 
-| Input Image | Output Image |
-| ----------- | ------------ |
-| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |
+| Input Image | Output Image |
+| ----------- | ------------ |
+| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |
\ No newline at end of file
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index f7c7886d0dfe..a52f1a892a6f 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -55,9 +55,13 @@ def get_test_image():
 
 def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
+    # To fetch the data names of the input to the model we list the inputs of the symbol graph
+    # and exclude the argument and auxiliary parameters from the list
+    data_names = [graph_input for graph_input in sym.list_inputs()
+                  if graph_input not in arg_params and graph_input not in aux_params]
     # create module
-    mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
-    mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
+    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
+    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
     mod.set_params(arg_params=arg_params, aux_params=aux_params)
 
     # run inference
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 92e7cb9c64e8..5192c6f8a858 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -31,7 +31,6 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
     def __init__(self):
         self._nodes = {}
         self._params = {}
-        self._renames = {}
         self._num_input = 0
         self._num_param = 0
 
@@ -72,9 +71,6 @@ def _convert_operator(self, node_name, op_name, attrs, inputs):
 
     def from_onnx(self, graph):
"""Construct symbol from onnx graph. - The inputs from onnx graph is vague, only providing "1", "2"... - For convenience, we rename the `real` input names to "input_0", - "input_1"... And renaming parameters to "param_0", "param_1"... Parameters ---------- @@ -98,17 +94,10 @@ def from_onnx(self, graph): for i in graph.input: if i.name in self._params: # i is a param instead of input - name_param = 'param_{}'.format(self._num_param) - self._num_param += 1 - self._params[name_param] = self._params.pop(i.name) - self._nodes[name_param] = symbol.Variable(name=name_param, - shape=self._params[name_param].shape) - self._renames[i.name] = name_param + self._nodes[i.name] = symbol.Variable(name=i.name, + shape=self._params[i.name].shape) else: - name_input = 'input_{}'.format(self._num_input) - self._num_input += 1 - self._nodes[name_input] = symbol.Variable(name=name_input) - self._renames[i.name] = name_input + self._nodes[i.name] = symbol.Variable(name=i.name) # For storing arg and aux params for the graph. auxDict = {} @@ -121,7 +110,7 @@ def from_onnx(self, graph): node_name = node.name.strip() node_name = node_name if node_name else None onnx_attr = self._parse_attr(node.attribute) - inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] + inputs = [self._nodes[i] for i in node.input] mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs) for k, i in zip(list(node.output), range(len(mxnet_sym.list_outputs()))): diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py index 47ea6c1585a6..114a2eb79903 100644 --- a/tests/python-pytest/onnx/backend_rep.py +++ b/tests/python-pytest/onnx/backend_rep.py @@ -64,9 +64,18 @@ def run(self, inputs, **kwargs): else: raise NotImplementedError("Only CPU context is supported for now") - mod = mx.mod.Module(symbol=self.symbol, data_names=['input_0'], context=ctx, + # To fetch the data names of the input to the model we list the inputs of the symbol graph + # and exclude the argument and auxiliary parameters from the list + data_names = [graph_input for graph_input in self.symbol.list_inputs() + if graph_input not in self.arg_params and graph_input not in self.aux_params] + + data_shapes = [] + for idx, input_name in enumerate(data_names): + data_shapes.append((input_name, inputs[idx].shape)) + + mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx, label_names=None) - mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], + mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None) mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params) diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py index ddc633e28f66..36cb9abacdd4 100644 --- a/tests/python-pytest/onnx/onnx_test.py +++ b/tests/python-pytest/onnx/onnx_test.py @@ -117,8 +117,7 @@ def test_super_resolution_example(): inputs = sym.list_inputs() assert len(inputs) == 9 - for i, input_param in enumerate(['param_7', 'param_5', 'param_3', 'param_1', - 'input_0', 'param_0', 'param_2', 'param_4', 'param_6']): + for i, input_param in enumerate(['9', '7', '5', '3', '1', '2', '4', '6', '8']): assert inputs[i] == input_param assert len(sym.list_outputs()) == 1 @@ -126,18 +125,16 @@ def test_super_resolution_example(): attrs_keys = sym.attr_dict().keys() assert len(attrs_keys) == 19 - for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7', - 'param_6', 'param_1', 'param_0', 'param_3', - 'param_2', 'reshape2', 'reshape3', 
'reshape0', - 'reshape1', 'convolution2', 'convolution3', - 'convolution0', 'convolution1', 'reshape5', - 'transpose0']): + for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0', + 'transpose0', '6', 'reshape0', 'reshape2', + 'reshape3', '3', 'reshape1', '5', '4', '7', + 'convolution1', '9', '2', 'convolution3', + 'reshape5', '8']): assert key_item in attrs_keys param_keys = arg_params.keys() assert len(param_keys) == 8 - for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6', - 'param_1', 'param_0', 'param_3', 'param_2']): + for i, param_item in enumerate(['3', '2', '5', '4', '7', '6', '9', '8']): assert param_item in param_keys logging.info("Asserted the result of the onnx model conversion") @@ -192,8 +189,10 @@ def test_bvlc_googlenet(): # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None) - mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None) + data_names = [graph_input for graph_input in sym.list_inputs() + if graph_input not in arg_params and graph_input not in aux_params] + mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) + mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True) # run inference @@ -214,8 +213,10 @@ def test_bvlc_reference_caffenet(): # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None) - mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None) + data_names = [graph_input for graph_input in sym.list_inputs() + if graph_input not in arg_params and graph_input not in aux_params] + mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) + mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True) # run inference @@ -236,8 +237,10 @@ def test_bvlc_rcnn_ilsvrc13(): # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None) - mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None) + data_names = [graph_input for graph_input in sym.list_inputs() + if graph_input not in arg_params and graph_input not in aux_params] + mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) + mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True) # run inference
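For reference, the pattern this patch applies throughout is sketched below, outside of the diff itself. It is a minimal sketch, assuming `sym`, `arg_params` and `aux_params` have already been returned by the MXNet ONNX importer used in these tutorials, and that `input_img` is a placeholder array whose shape is only illustrative; those variable names and the dummy shape are assumptions, not part of the patch.

```python
import numpy as np
import mxnet as mx

# Assumption (not part of the patch): sym, arg_params and aux_params come from the
# ONNX importer, and input_img is a dummy input; the 1x1x224x224 shape is a placeholder.
input_img = np.zeros((1, 1, 224, 224), dtype=np.float32)

# Discover the data inputs instead of hard-coding a name such as 'input_0':
# every symbol input that is neither an argument nor an auxiliary parameter
# is a data input of the graph.
data_names = [graph_input for graph_input in sym.list_inputs()
              if graph_input not in arg_params and graph_input not in aux_params]

# Create and bind a module with the discovered names, then load the weights.
mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
mod.set_params(arg_params=arg_params, aux_params=aux_params,
               allow_missing=True, allow_extra=True)

# Run a forward pass on the dummy input and fetch the output.
mod.forward(mx.io.DataBatch([mx.nd.array(input_img)]))
output = mod.get_outputs()[0].asnumpy()
```

Deriving `data_names` from `sym.list_inputs()` keeps downstream code independent of whatever input names a particular ONNX file carries ('gpu_0/data_0', '1', and so on), which is the point of dropping the hard-coded 'input_0'.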