"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
import warnings
from .. import symbol as _sym

__all__ = ['from_mxnet']

def _required_attr(attr, key):
    assert isinstance(attr, dict)
    if key not in attr:
        raise AttributeError("Required attribute {} not found.".format(key))
    return attr[key]

def _raise_not_supported(attr, op='nnvm'):
    err = "{} is not supported in {}.".format(attr, op)
    raise NotImplementedError(err)

def _warn_not_used(attr, op='nnvm'):
    err = "{} is ignored in {}.".format(attr, op)
    warnings.warn(err)

def _parse_tshape(tshape):
    """Parse a shape string such as '(2, 2)' into a list of ints."""
    return [int(x.strip()) for x in tshape.strip('()').split(',')]

def _parse_bool_str(attr, key, default='False'):
    """Parse a boolean-valued attribute string into a Python bool."""
    return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']

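# A quick doctest-style illustration of the two parsers above; the attribute
# strings mirror the raw values MXNet stores in a symbol's attr dict:
#
#   >>> _parse_tshape('(2, 2)')
#   [2, 2]
#   >>> _parse_bool_str({'no_bias': 'True'}, 'no_bias')
#   True
#   >>> _parse_bool_str({}, 'no_bias')
#   False
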
def _rename(new_name):
    def impl(attr):
        return new_name, attr
    return impl

def _variable(attrs):
    return "Variable", attrs

def _pooling(attrs):
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'pool2d')
    global_pool = 'global' if _parse_bool_str(attrs, 'global_pool') else ''
    pool_type = _required_attr(attrs, 'pool_type')
    if pool_type not in ['avg', 'max']:
        _raise_not_supported('non-avg/max', 'pool2d')
    op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {}
    if not global_pool:
        new_attrs['pool_size'] = kernel
        new_attrs['strides'] = attrs.get('stride', (1, 1))
        new_attrs['padding'] = attrs.get('pad', (0, 0))
        new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
    return op_name, new_attrs

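# Sketch of the mapping _pooling performs; the attribute dict below is a
# hypothetical example of what MXNet emits (string values pass through
# untouched where nnvm accepts them):
#
#   >>> _pooling({'kernel': '(2, 2)', 'pool_type': 'max', 'stride': '(2, 2)'})
#   ('max_pool2d', {'pool_size': [2, 2], 'strides': '(2, 2)',
#                   'padding': (0, 0), 'ceil_mode': False})
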
def _batch_norm(attrs):
    if _parse_bool_str(attrs, 'output_mean_var'):
        _raise_not_supported('output_mean_var', 'batch_norm')
    if _parse_bool_str(attrs, 'fix_gamma'):
        _warn_not_used('fix_gamma', 'batch_norm')
    if _parse_bool_str(attrs, 'use_global_stats'):
        _warn_not_used('use_global_stats', 'batch_norm')
    if 'momentum' in attrs:
        _warn_not_used('momentum', 'batch_norm')
    op_name, new_attrs = 'batch_norm', {}
    new_attrs['axis'] = attrs.get('axis', 1)
    new_attrs['epsilon'] = attrs.get('eps', 0.001)
    new_attrs['center'] = True
    new_attrs['scale'] = True
    return op_name, new_attrs

def _concat(attrs):
    op_name = 'concatenate'
    new_attrs = {'axis': attrs.get('dim', 1)}
    return op_name, new_attrs

def _conv2d(attrs):
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'conv2d')
    layout = attrs.get('layout', 'NCHW')
    if layout not in ['NCHW', 'NHWC']:
        _raise_not_supported('layout: ' + layout, 'conv2d')
    op_name, new_attrs = 'conv2d', {}
    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
    new_attrs['kernel_size'] = kernel
    new_attrs['strides'] = attrs.get('stride', (1, 1))
    new_attrs['padding'] = attrs.get('pad', (0, 0))
    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
    new_attrs['groups'] = attrs.get('num_group', 1)
    new_attrs['layout'] = layout
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

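# Hypothetical example of _conv2d's attribute translation (values are
# illustrative; defaults fill in everything the dict omits):
#
#   >>> _conv2d({'kernel': '(3, 3)', 'num_filter': '64', 'pad': '(1, 1)'})
#   ('conv2d', {'channels': '64', 'kernel_size': [3, 3], 'strides': (1, 1),
#               'padding': '(1, 1)', 'dilation': (1, 1), 'groups': 1,
#               'layout': 'NCHW', 'use_bias': True})
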
def _conv2d_transpose(attrs):
    if 'target_shape' in attrs:
        _raise_not_supported('target_shape', 'conv2d_transpose')
    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
    if len(kernel) != 2:
        _raise_not_supported('non-2d kernel', 'conv2d_transpose')
    layout = attrs.get('layout', 'NCHW')
    if layout not in ['NCHW', 'NHWC']:
        _raise_not_supported('layout: ' + layout, 'conv2d_transpose')
    op_name, new_attrs = 'conv2d_transpose', {}
    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
    new_attrs['kernel_size'] = kernel
    new_attrs['strides'] = attrs.get('stride', (1, 1))
    new_attrs['output_padding'] = attrs.get('adj', (0, 0))
    new_attrs['padding'] = attrs.get('pad', (0, 0))
    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
    new_attrs['groups'] = attrs.get('num_group', 1)
    new_attrs['layout'] = layout
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

def _dense(attrs):
    op_name, new_attrs = 'dense', {}
    new_attrs['units'] = _required_attr(attrs, 'num_hidden')
    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
    return op_name, new_attrs

def _dropout(attrs):
    op_name, new_attrs = 'dropout', {}
    new_attrs['rate'] = attrs.get('p', 0.5)
    return op_name, new_attrs

def _leaky_relu(attrs):
    act_type = _required_attr(attrs, 'act_type')
    if act_type not in ['leaky']:
        _raise_not_supported('act_type: ' + act_type, 'leaky_relu')
    op_name, new_attrs = 'leaky_relu', {}
    new_attrs['alpha'] = attrs.get('slope', 0.25)
    return op_name, new_attrs

def _activations(attrs):
    act_type = _required_attr(attrs, 'act_type')
    if act_type not in ['relu', 'sigmoid', 'tanh']:
        _raise_not_supported('act_type: ' + act_type, 'activation')
    op_name, new_attrs = act_type, {}
    return op_name, new_attrs

| 141 | + |
| 142 | +def _reshape(attrs): |
| 143 | + if _parse_bool_str(attrs, 'reverse'): |
| 144 | + _raise_not_supported('reverse', 'reshape') |
| 145 | + op_name, new_attrs = 'reshape', {} |
| 146 | + new_attrs['shape'] = _required_attr(attrs, 'shape') |
| 147 | + return op_name, new_attrs |
| 148 | + |
| 149 | +def _split(attrs): |
| 150 | + if _parse_bool_str(attrs, 'squeeze_axis'): |
| 151 | + _raise_not_supported('squeeze_axis', 'split') |
| 152 | + op_name, new_attrs = 'split', {} |
| 153 | + new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs') |
| 154 | + new_attrs['axis'] = attrs.get('axis', 1) |
| 155 | + return op_name, new_attrs |
| 156 | + |
| 157 | +_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', |
| 158 | + '__div_symbol__', '__mul_scalar__', '__mul_symbol__', |
| 159 | + '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', |
| 160 | + '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__', |
| 161 | + 'broadcast_add', 'broadcast_div', 'broadcast_mul', |
| 162 | + 'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add', |
| 163 | + 'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp', |
| 164 | + 'flatten', 'log', 'log_softmax', 'max', 'min', 'negative', |
| 165 | + 'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose'] |
| 166 | + |
| 167 | +_convert_map = { |
| 168 | + 'null' : _variable, |
| 169 | + 'Activation' : _activations, |
| 170 | + 'BatchNorm' : _batch_norm, |
| 171 | + 'BatchNorm_v1' : _batch_norm, |
| 172 | + 'Cast' : _rename('cast'), |
| 173 | + 'Concat' : _concat, |
| 174 | + 'Convolution' : _conv2d, |
| 175 | + 'Convolution_v1': _conv2d, |
| 176 | + 'Deconvolution' : _conv2d_transpose, |
| 177 | + 'Dropout' : _dropout, |
| 178 | + 'Flatten' : _rename('flatten'), |
| 179 | + 'FullyConnected': _dense, |
| 180 | + 'LeakyReLU' : _leaky_relu, |
| 181 | + 'Pooling' : _pooling, |
| 182 | + 'Pooling_v1' : _pooling, |
| 183 | + 'Reshape' : _reshape, |
| 184 | + 'Softmax' : _rename('softmax'), |
| 185 | + 'concat' : _concat, |
| 186 | + 'max_axis' : _rename('max'), |
| 187 | + 'min_axis' : _rename('min'), |
| 188 | + 'reshape' : _reshape, |
| 189 | + 'sum_axis' : _rename('sum'), |
| 190 | +} |
| 191 | + |
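# Supporting another MXNet op is a matter of extending the table above; a
# hedged sketch (this mapping is illustrative and not registered here):
#
#   _convert_map['SoftmaxActivation'] = _rename('softmax')
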
def _convert_symbol(op_name, attrs,
                    identity_list=_identity_list,
                    convert_map=_convert_map):
    """Convert from an mxnet op to an nnvm op.
    Ops whose names or attributes differ between the two frameworks,
    such as Convolution -> conv2d, must be converted explicitly via
    convert_map; ops in identity_list pass through unchanged.

    Parameters
    ----------
    op_name : str
        Operator name, such as Convolution, FullyConnected
    attrs : dict
        Dict of operator attributes
    identity_list : list
        List of operators that don't require conversion
    convert_map : dict
        Dict of name : callable, where name is the name of an op that
        requires conversion and the callable takes attrs and returns
        (new_op_name, new_attrs)

    Returns
    -------
    (op_name, attrs)
        Converted (op_name, attrs) for nnvm.
    """
    if op_name in identity_list:
        pass
    elif op_name in convert_map:
        op_name, attrs = convert_map[op_name](attrs)
    else:
        _raise_not_supported('Operator: ' + op_name)
    op = getattr(_sym, op_name, None)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
    return op, attrs

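# Both paths through _convert_symbol, sketched with hypothetical attrs:
#
#   op, attrs = _convert_symbol('relu', {})
#   # identity path: op is nnvm's relu; attrs pass through unchanged
#   op, attrs = _convert_symbol('FullyConnected', {'num_hidden': '10'})
#   # converted path: op is nnvm's dense; attrs == {'units': '10', 'use_bias': True}
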
def _is_mxnet_group_symbol(symbol):
    """Internal check for an mxnet group symbol (multiple outputs)."""
    return len(symbol.list_outputs()) > 1

def _as_list(arr):
    """Force arr to be a list; return it unchanged if it already is one."""
    if isinstance(arr, list):
        return arr
    return [arr]

def _from_mxnet_impl(symbol, graph):
    """Convert an mxnet symbol to an nnvm symbol.
    Reconstructs the nnvm symbol by traversing the mxnet symbol graph,
    converting op names and attrs along the way.

    Parameters
    ----------
    symbol : mxnet.sym.Symbol
        MXNet symbol to convert. It shares a similar graph structure,
        but its op names and attrs are not always compatible with nnvm.
    graph : dict
        Cache of already-converted nodes, keyed by name, for reuse.

    Returns
    -------
    nnvm.sym.Symbol
        Converted symbol
    """
    try:
        from mxnet import sym as mx_sym
    except ImportError as e:
        raise ImportError('{}. MXNet is required to parse symbols.'.format(e))

    if not isinstance(symbol, mx_sym.Symbol):
        raise ValueError("Provided {}, while MXNet symbol is expected".format(type(symbol)))

    if _is_mxnet_group_symbol(symbol):
        return [_from_mxnet_impl(s, graph) for s in symbol]

    name = symbol.attr('name')
    node = graph.get(name, None)
    if node:
        return node
    if symbol.get_children():
        op_name = symbol.attr('op_name')
    else:
        # For leaf symbols, recover the op name from the JSON graph
        # ('null' for variables).
        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
    attr = symbol.list_attr()
    new_op, new_attr = _convert_symbol(op_name, attr)
    if new_op == _sym.Variable:
        node = new_op(name=name, **new_attr)
    else:
        childs = symbol.get_children()
        childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
        childs = [x for y in childs for x in _as_list(y)]  # expand group symbols
        node = new_op(name=name, *childs, **new_attr)
    graph[name] = node
    return node


def from_mxnet(symbol):
    """Convert an mxnet.Symbol to a compatible nnvm.Symbol.

    Parameters
    ----------
    symbol : mxnet.Symbol
        MXNet symbol

    Returns
    -------
    nnvm.Symbol
        Compatible nnvm symbol
    """
    return _from_mxnet_impl(symbol, {})
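
# Minimal usage sketch, assuming MXNet is installed; the two-layer network
# below is illustrative:
#
#   import mxnet as mx
#   data = mx.sym.Variable('data')
#   fc = mx.sym.FullyConnected(data, num_hidden=10, name='fc1')
#   net = mx.sym.Activation(fc, act_type='relu', name='relu1')
#   nnvm_sym = from_mxnet(net)  # -> nnvm.symbol.Symbol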