Skip to content

Commit 3b47db7

Browse files
zhresholdtqchen
authored andcommitted
[Tutorial] mxnet (#47)
* [Tutorial] mxnet update add from_gluon add to __init__ fix tutorial and from_gluon fix doc lint merge from_mxnet fix fix fix tutorial fix fix header * fix tutorial * fix data * fix
1 parent c503e08 commit 3b47db7

File tree

3 files changed

+143
-18
lines changed

3 files changed

+143
-18
lines changed

nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,6 @@ def _from_mxnet_impl(symbol, graph):
256256
nnvm.sym.Symbol
257257
Converted symbol
258258
"""
259-
try:
260-
from mxnet import sym as mx_sym # pylint: disable=import-self
261-
except ImportError as e:
262-
raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
263-
264-
if not isinstance(symbol, mx_sym.Symbol):
265-
raise ValueError("Provided {}, while MXNet symbol is expected", type(symbol))
266-
267259
if _is_mxnet_group_symbol(symbol):
268260
return [_from_mxnet_impl(s, graph) for s in symbol]
269261

@@ -294,7 +286,7 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
294286
295287
Parameters
296288
----------
297-
symbol : mxnet.Symbol
289+
symbol : mxnet.Symbol or mxnet.gluon.HybridBlock
298290
MXNet symbol
299291
300292
arg_params : dict of str to mx.NDArray
@@ -305,18 +297,36 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
305297
306298
Returns
307299
-------
308-
net: nnvm.Symbol
300+
sym : nnvm.Symbol
309301
Compatible nnvm symbol
310302
311303
params : dict of str to tvm.NDArray
312304
The parameter dict to be used by nnvm
313305
"""
314-
sym = _from_mxnet_impl(symbol, {})
315-
params = {}
316-
arg_params = arg_params if arg_params else {}
317-
aux_params = aux_params if aux_params else {}
318-
for k, v in arg_params.items():
319-
params[k] = tvm.nd.array(v.asnumpy())
320-
for k, v in aux_params.items():
321-
params[k] = tvm.nd.array(v.asnumpy())
306+
try:
307+
import mxnet as mx # pylint: disable=import-self
308+
except ImportError as e:
309+
raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
310+
311+
if isinstance(symbol, mx.sym.Symbol):
312+
sym = _from_mxnet_impl(symbol, {})
313+
params = {}
314+
arg_params = arg_params if arg_params else {}
315+
aux_params = aux_params if aux_params else {}
316+
for k, v in arg_params.items():
317+
params[k] = tvm.nd.array(v.asnumpy())
318+
for k, v in aux_params.items():
319+
params[k] = tvm.nd.array(v.asnumpy())
320+
elif isinstance(symbol, mx.gluon.HybridBlock):
321+
data = mx.sym.Variable('data')
322+
sym = symbol(data)
323+
sym = _from_mxnet_impl(sym, {})
324+
params = {}
325+
for k, v in symbol.collect_params().items():
326+
params[k] = tvm.nd.array(v.data().asnumpy())
327+
elif isinstance(symbol, mx.gluon.Block):
328+
raise NotImplementedError("The dynamic Block is not supported yet.")
329+
else:
330+
msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
331+
raise ValueError(msg)
322332
return sym, params

nnvm/python/nnvm/testing/resnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
2424
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
2525
'''
26+
# pylint: disable=unused-argument
2627
import numpy as np
2728
from .. import symbol as sym
2829
from . utils import create_workload

nnvm/tutorials/from_mxnet.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Compiling MXNet Models with NNVM
3+
================================
4+
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
5+
6+
This article is an introductory tutorial to deploy mxnet models with NNVM.
7+
8+
For us to begin with, mxnet module is required to be installed.
9+
10+
A quick solution is
11+
```
12+
pip install mxnet --user
13+
```
14+
or please refer to offical installation guide.
15+
https://mxnet.incubator.apache.org/versions/master/install/index.html
16+
"""
17+
# some standard imports
18+
import mxnet as mx
19+
import nnvm
20+
import tvm
21+
import numpy as np
22+
23+
######################################################################
24+
# Download Resnet18 model from Gluon Model Zoo
25+
# ---------------------------------------------
26+
# In this section, we download a pretrained imagenet model and classify an image.
27+
from mxnet.gluon.model_zoo.vision import get_model
28+
from mxnet.gluon.utils import download
29+
import Image
30+
from matplotlib import pyplot as plt
31+
block = get_model('resnet18_v1', pretrained=True)
32+
img_name = 'cat.jpg'
33+
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
34+
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
35+
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
36+
'imagenet1000_clsid_to_human.txt'])
37+
synset_name = 'synset.txt'
38+
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
39+
download(synset_url, synset_name)
40+
with open(synset_name) as f:
41+
synset = eval(f.read())
42+
image = Image.open(img_name).resize((224, 224))
43+
plt.imshow(image)
44+
plt.show()
45+
46+
def transform_image(image):
47+
image = np.array(image) - np.array([123., 117., 104.])
48+
image /= np.array([58.395, 57.12, 57.375])
49+
image = image.transpose((2, 0, 1))
50+
image = image[np.newaxis, :]
51+
return image
52+
53+
x = transform_image(image)
54+
print('x', x.shape)
55+
56+
######################################################################
57+
# Compile the Graph
58+
# -----------------
59+
# Now we would like to port the Gluon model to a portable computational graph.
60+
# It's as easy as several lines.
61+
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
62+
sym, params = nnvm.frontend.from_mxnet(block)
63+
# we want a probability so add a softmax operator
64+
sym = nnvm.sym.softmax(sym)
65+
66+
######################################################################
67+
# now compile the graph
68+
import nnvm.compiler
69+
target = 'cuda'
70+
shape_dict = {'data': x.shape}
71+
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)
72+
73+
######################################################################
74+
# Execute the portable graph on TVM
75+
# ---------------------------------
76+
# Now, we would like to reproduce the same forward computation using TVM.
77+
from tvm.contrib import graph_runtime
78+
ctx = tvm.gpu(0)
79+
dtype = 'float32'
80+
m = graph_runtime.create(graph, lib, ctx)
81+
# set inputs
82+
m.set_input('data', tvm.nd.array(x.astype(dtype)))
83+
m.set_input(**params)
84+
# execute
85+
m.run()
86+
# get outputs
87+
tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))
88+
top1 = np.argmax(tvm_output)
89+
print('TVM prediction top-1:', top1, synset[top1])
90+
91+
######################################################################
92+
# Use MXNet symbol with pretrained weights
93+
# ----------------------------------------
94+
# MXNet often use `arg_prams` and `aux_params` to store network parameters
95+
# separately, here we show how to use these weights with existing API
96+
def block2symbol(block):
97+
data = mx.sym.Variable('data')
98+
sym = block(data)
99+
args = {}
100+
auxs = {}
101+
for k, v in block.collect_params().items():
102+
args[k] = mx.nd.array(v.data().asnumpy())
103+
return sym, args, auxs
104+
mx_sym, args, auxs = block2symbol(block)
105+
# usually we would save/load it as checkpoint
106+
mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
107+
# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk
108+
109+
######################################################################
110+
# for a normal mxnet model, we start from here
111+
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
112+
# now we use the same API to get NNVM compatible symbol
113+
nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
114+
# repeat the same steps to run this model using TVM

0 commit comments

Comments
 (0)