
Commit 389a00f

zhreshold authored and tqchen committed
init mxnet converter (apache#27)
graph backup update finish mxnet converter fix fix various add tests fix add multi networks uses model_zoo fix tests minor fix fix graph fix
1 parent 2b3d2e2 commit 389a00f

File tree

9 files changed: +947 −0 lines changed


nnvm/python/nnvm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,5 +7,6 @@
 from . import symbol as sym
 from . import symbol
 from ._base import NNVMError
+from . import frontend

 __version__ = _base.__version__

nnvm/python/nnvm/frontend/__init__.py

Lines changed: 3 additions & 0 deletions

"""Frontend package."""
from __future__ import absolute_import
from .mxnet import from_mxnet

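Together with the `from . import frontend` hook above, these three lines make the converter reachable from the top-level package. A minimal sketch, assuming MXNet and this nnvm build are importable:

import mxnet as mx
import nnvm.frontend

mx_sym = mx.sym.Variable('data')             # smallest possible MXNet graph
nnvm_sym = nnvm.frontend.from_mxnet(mx_sym)  # returns an nnvm Variable symbol
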
nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 301 additions & 0 deletions
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
from .. import symbol as _sym

__all__ = ['from_mxnet']

def _required_attr(attr, key):
    assert isinstance(attr, dict)
    if key not in attr:
        raise AttributeError("Required attribute {} not found.".format(key))
    return attr[key]

def _raise_not_supported(attr, op='nnvm'):
    err = "{} is not supported in {}.".format(attr, op)
    raise NotImplementedError(err)

def _warn_not_used(attr, op='nnvm'):
    import warnings
    err = "{} is ignored in {}.".format(attr, op)
    warnings.warn(err)

def _parse_tshape(tshape):
    """Parse a shape string, e.g. '(2, 2)' -> [2, 2]."""
    return [int(x.strip()) for x in tshape.strip('()').split(',')]

def _parse_bool_str(attr, key, default='False'):
    """Parse a bool string to a boolean."""
    return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']

def _rename(new_name):
    def impl(attr):
        return new_name, attr
    return impl

def _variable(attrs):
    return "Variable", attrs

def _pooling(attrs):
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'pool2d')
    global_pool = 'global' if _parse_bool_str(attrs, 'global_pool') else ''
    pool_type = _required_attr(attrs, 'pool_type')
    if pool_type not in ['avg', 'max']:
        _raise_not_supported('non-avg/max', 'pool2d')
    op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {}
    # new_attrs['layout'] = 'NCHW'
    if not global_pool:
        new_attrs['pool_size'] = kernel
        new_attrs['strides'] = attrs.get('stride', (1, 1))
        new_attrs['padding'] = attrs.get('pad', (0, 0))
        new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
    return op_name, new_attrs

def _batch_norm(attrs):
    if _parse_bool_str(attrs, 'output_mean_var'):
        _raise_not_supported('output_mean_var', 'batch_norm')
    if _parse_bool_str(attrs, 'fix_gamma'):
        _warn_not_used('fix_gamma', 'batch_norm')
    if _parse_bool_str(attrs, 'use_global_stats'):
        _warn_not_used('use_global_stats', 'batch_norm')
    if 'momentum' in attrs:
        # momentum is a float string, so check presence rather than parsing a bool
        _warn_not_used('momentum', 'batch_norm')
    op_name, new_attrs = 'batch_norm', {}
    new_attrs['axis'] = attrs.get('axis', 1)
    new_attrs['epsilon'] = attrs.get('eps', 0.001)
    new_attrs['center'] = True
    new_attrs['scale'] = True
    return op_name, new_attrs

def _concat(attrs):
    op_name = 'concatenate'
    new_attrs = {'axis': attrs.get('dim', 1)}
    return op_name, new_attrs

def _conv2d(attrs):
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'conv2d')
    layout = attrs.get('layout', 'NCHW')
    if layout not in ['NCHW', 'NHWC']:
        _raise_not_supported('layout: ' + layout, 'conv2d')
    op_name, new_attrs = 'conv2d', {}
    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
    new_attrs['kernel_size'] = kernel
    new_attrs['strides'] = attrs.get('stride', (1, 1))
    new_attrs['padding'] = attrs.get('pad', (0, 0))
    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
    new_attrs['groups'] = attrs.get('num_group', 1)
    new_attrs['layout'] = layout
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

def _conv2d_transpose(attrs):
    if 'target_shape' in attrs:
        _raise_not_supported('target_shape', 'conv2d_transpose')
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'conv2d_transpose')
    layout = attrs.get('layout', 'NCHW')
    if layout not in ['NCHW', 'NHWC']:
        _raise_not_supported('layout: ' + layout, 'conv2d_transpose')
    op_name, new_attrs = 'conv2d_transpose', {}
    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
    new_attrs['kernel_size'] = kernel
    new_attrs['strides'] = attrs.get('stride', (1, 1))
    new_attrs['output_padding'] = attrs.get('adj', (0, 0))
    new_attrs['padding'] = attrs.get('pad', (0, 0))
    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
    new_attrs['groups'] = attrs.get('num_group', 1)
    new_attrs['layout'] = layout
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

def _dense(attrs):
    op_name, new_attrs = 'dense', {}
    new_attrs['units'] = _required_attr(attrs, 'num_hidden')
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

def _dropout(attrs):
    op_name, new_attrs = 'dropout', {}
    new_attrs['rate'] = attrs.get('p', 0.5)
    return op_name, new_attrs

def _leaky_relu(attrs):
    act_type = _required_attr(attrs, 'act_type')
    if act_type not in ['leaky']:
        _raise_not_supported('act_type: ' + act_type)
    op_name, new_attrs = 'leaky_relu', {}
    new_attrs['alpha'] = attrs.get('slope', 0.25)
    return op_name, new_attrs

def _activations(attrs):
    act_type = _required_attr(attrs, 'act_type')
    if act_type not in ['relu', 'sigmoid', 'tanh']:
        _raise_not_supported('act_type: ' + act_type)
    op_name, new_attrs = act_type, {}
    return op_name, new_attrs

def _reshape(attrs):
    if _parse_bool_str(attrs, 'reverse'):
        _raise_not_supported('reverse', 'reshape')
    op_name, new_attrs = 'reshape', {}
    new_attrs['shape'] = _required_attr(attrs, 'shape')
    return op_name, new_attrs

def _split(attrs):
    if _parse_bool_str(attrs, 'squeeze_axis'):
        _raise_not_supported('squeeze_axis', 'split')
    op_name, new_attrs = 'split', {}
    new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
    new_attrs['axis'] = attrs.get('axis', 1)
    return op_name, new_attrs

_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                  '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
                  '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
                  '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
                  'broadcast_add', 'broadcast_div', 'broadcast_mul',
                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']

_convert_map = {
    'null'          : _variable,
    'Activation'    : _activations,
    'BatchNorm'     : _batch_norm,
    'BatchNorm_v1'  : _batch_norm,
    'Cast'          : _rename('cast'),
    'Concat'        : _concat,
    'Convolution'   : _conv2d,
    'Convolution_v1': _conv2d,
    'Deconvolution' : _conv2d_transpose,
    'Dropout'       : _dropout,
    'Flatten'       : _rename('flatten'),
    'FullyConnected': _dense,
    'LeakyReLU'     : _leaky_relu,
    'Pooling'       : _pooling,
    'Pooling_v1'    : _pooling,
    'Reshape'       : _reshape,
    'Softmax'       : _rename('softmax'),
    'concat'        : _concat,
    'max_axis'      : _rename('max'),
    'min_axis'      : _rename('min'),
    'reshape'       : _reshape,
    'sum_axis'      : _rename('sum'),
}

def _convert_symbol(op_name, attrs,
                    identity_list=_identity_list,
                    convert_map=_convert_map):
    """Convert an mxnet op to an nnvm op.

    The converter must specify some conversions explicitly to
    support gluon-format ops such as conv2d.

    Parameters
    ----------
    op_name : str
        Operator name, such as Convolution, FullyConnected
    attrs : dict
        Dict of operator attributes
    identity_list : list
        List of operators that don't require conversion
    convert_map : dict
        Dict of name : callable, where name is the name of an op that
        requires conversion to nnvm and the callable takes attrs and
        returns (new_op_name, new_attrs)

    Returns
    -------
    (op, attrs)
        The nnvm symbol constructor and the converted attributes.
    """
    if op_name in identity_list:
        pass
    elif op_name in convert_map:
        op_name, attrs = convert_map[op_name](attrs)
    else:
        _raise_not_supported('Operator: ' + op_name)
    op = getattr(_sym, op_name, None)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
    return op, attrs

def _is_mxnet_group_symbol(symbol):
    """Internal check for an mxnet group symbol (more than one output)."""
    return len(symbol.list_outputs()) > 1

def _as_list(arr):
    """Force being a list, ignore if already is."""
    if isinstance(arr, list):
        return arr
    return [arr]

def _from_mxnet_impl(symbol, graph):
    """Convert an mxnet symbol to an nnvm symbol.

    Reconstructs the nnvm symbol by traversing the mxnet symbol graph.

    Parameters
    ----------
    symbol : mxnet.sym.Symbol
        Incompatible symbol from mxnet, sharing a similar graph structure.
        The op_name and attrs inside are not always compatible.
    graph : dict
        Cache of converted nodes keyed by name, so shared subgraphs
        are translated once and reused.

    Returns
    -------
    nnvm.sym.Symbol
        Converted symbol
    """
    try:
        from mxnet import sym as mx_sym
    except ImportError as e:
        raise ImportError('{}. MXNet is required to parse symbols.'.format(e))

    if not isinstance(symbol, mx_sym.Symbol):
        raise ValueError("Provided {}, while an MXNet symbol is expected".format(type(symbol)))

    if _is_mxnet_group_symbol(symbol):
        return [_from_mxnet_impl(s, graph) for s in symbol]

    name = symbol.attr('name')
    node = graph.get(name, None)
    if node:
        return node
    # leaf nodes carry no op_name attribute; read the op from the json form
    if symbol.get_children():
        op_name = symbol.attr('op_name')
    else:
        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
    attr = symbol.list_attr()
    new_op, new_attr = _convert_symbol(op_name, attr)
    if new_op == _sym.Variable:
        node = new_op(name=name, **new_attr)
    else:
        childs = symbol.get_children()
        childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
        childs = [x for y in childs for x in _as_list(y)]  # expand group symbol
        node = new_op(name=name, *childs, **new_attr)
    graph[name] = node
    return node


def from_mxnet(symbol):
    """Convert an mxnet.Symbol to a compatible nnvm.Symbol.

    Parameters
    ----------
    symbol : mxnet.Symbol
        MXNet symbol

    Returns
    -------
    nnvm.Symbol
        Compatible nnvm symbol
    """
    return _from_mxnet_impl(symbol, {})
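
A slightly larger usage sketch showing both conversion paths at work: FullyConnected and Activation go through _convert_map, while softmax is already in _identity_list and passes through unchanged. The network is illustrative, and nnvm.graph.create is assumed to be available from the nnvm Python package:

import mxnet as mx
import nnvm

data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=64)  # -> nnvm 'dense'
act1 = mx.sym.Activation(fc1, name='relu1', act_type='relu')  # -> nnvm 'relu'
out = mx.sym.softmax(act1, name='softmax')                    # identity mapping

nnvm_sym = nnvm.frontend.from_mxnet(out)
# building a graph exercises the converted symbol end to end
graph = nnvm.graph.create(nnvm_sym)
print(graph.json())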
Lines changed: 22 additions & 0 deletions
from __future__ import absolute_import
from . import mlp, resnet, vgg

_num_class = 1000

# mlp fc
mx_mlp = mlp.get_symbol(_num_class)
nnvm_mlp = mlp.get_symbol_nnvm(_num_class)

# resnet fc
mx_resnet = {}
nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]:
    mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
    nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3, 224, 224', lib='nnvm')

# vgg fc
mx_vgg = {}
nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]:
    mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
    nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer)
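
These module-level pairs exist so tests can sweep whole families of networks. A hedged sketch of how a test might consume them (the model_zoo import path is an assumption; adjust to the actual test layout):

import nnvm
from nnvm.frontend import from_mxnet
import model_zoo  # hypothetical import path for the package above

# every MXNet definition should convert and build into an nnvm graph
for num_layer, mx_sym in model_zoo.mx_resnet.items():
    nnvm_sym = from_mxnet(mx_sym)
    nnvm.graph.create(nnvm_sym)  # building a graph exercises the converted symbol
    print('resnet-{} converted'.format(num_layer))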
Lines changed: 44 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
A simple multilayer perceptron.
"""
import mxnet as mx
import nnvm

def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.sym.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.softmax(data=fc3, name='softmax')
    return mlp

def get_symbol_nnvm(num_classes=10, **kwargs):
    data = nnvm.symbol.Variable('data')
    data = nnvm.sym.flatten(data=data)
    fc1 = nnvm.symbol.dense(data=data, name='fc1', units=128)
    act1 = nnvm.symbol.relu(data=fc1, name='relu1')
    fc2 = nnvm.symbol.dense(data=act1, name='fc2', units=64)
    act2 = nnvm.symbol.relu(data=fc2, name='relu2')
    fc3 = nnvm.symbol.dense(data=act2, name='fc3', units=num_classes)
    mlp = nnvm.symbol.softmax(data=fc3, name='softmax')
    return mlp
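
Since get_symbol and get_symbol_nnvm mirror each other layer for layer, converting the former should reproduce the op sequence of the latter. A sketch under that assumption (the mlp import path is hypothetical, and attribute formatting may still differ between the two graphs, so only op names are compared):

import json
import nnvm
from nnvm.frontend import from_mxnet
from mlp import get_symbol, get_symbol_nnvm  # hypothetical import path

def op_sequence(sym):
    """List the ops of an nnvm symbol in graph order."""
    graph = nnvm.graph.create(sym)
    return [node['op'] for node in json.loads(graph.json())['nodes']]

converted = from_mxnet(get_symbol(num_classes=10))
reference = get_symbol_nnvm(num_classes=10)
assert op_sequence(converted) == op_sequence(reference)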
