@@ -1449,41 +1449,42 @@ def validate(self, op, **kwargs):
1449
1449
return op
1450
1450
1451
1451
def fuse_transpose(self, op, **kwargs):
    """Fuse a child ``Transpose`` into this reduce op.

    Delegates to the shared ``fuse_transpose_reduce`` pass, which rewrites
    ``Transpose(axis) -> reduce(dims1)`` into a direct reduction over the
    remapped axes ``[axis[i] for i in dims1]`` (see that helper for the
    full transformation rules, including the ``keepdims`` handling).

    Parameters
    ----------
    op : mxnet symbol
        The reduce operator being visited.
    kwargs : dict
        Pass context; only ``infer_shapes`` (name -> shape table) is used.

    Returns
    -------
    The rewritten (or unchanged) symbol.
    """
    # The previous inline implementation (sum-only, keepdims=False-only)
    # was superseded by fuse_transpose_reduce; the commented-out copy has
    # been removed rather than kept as dead code.
    return fuse_transpose_reduce(op, kwargs["infer_shapes"])
def calculate_ops (self , op , ** kwargs ):
1489
1490
infer_shapes = kwargs ['infer_shapes' ]
@@ -3103,3 +3104,196 @@ def fusable_cvm_precision_attr(op):
3103
3104
assert is_fusable_cvm_precision (op )
3104
3105
attr = op .list_attr ()
3105
3106
return get_attr (attr , 'precision' ), get_attr (attr , 'shift_bit' , 0 )
3107
+
3108
def sum_and_rightshift(ops, axis, shift_bit):
    """Sum each symbol in *ops* along *axis* (keepdims) and right-shift it.

    Each partial symbol is reduced with ``mx.sym.sum(keepdims=True)`` and
    immediately divided by ``2**shift_bit`` via the ``right_shift`` custom
    op, keeping accumulator bit-widths bounded.

    NOTE: the input list is consumed (popped empty), so the returned list
    is in reverse order relative to *ops* — callers here only re-sum the
    results, so ordering is immaterial.
    """
    reduced = []
    while ops:
        sym = ops.pop()
        sym = mx.sym.sum(
            sym, axis=axis, keepdims=True, name=N.n('sum'))
        sym = mx.sym.Custom(
            sym, shift_bit=shift_bit, name=N.n('custom'),
            op_type='right_shift')
        reduced.append(sym)
    return reduced
3120
+
3121
+
3122
@register_transformer("mean")
class Mean(Transformer):
    """Rewrite ``mean`` into power-of-two sums followed by right shifts."""

    def fuse_transpose(self, op, **kwargs):
        # Shared with Sum: fold a child Transpose into the reduction axes.
        return fuse_transpose_reduce(op, kwargs["infer_shapes"])

    def rewrite(self, op, **kwargs):
        # (unused `name` local removed)
        return self._decompose_axis(op, kwargs['infer_shapes'])

    def _decompose_axis(self, op, infer_shapes):
        """Lower ``mean(axis)`` into quantization-friendly primitives.

        ``mean == sum / prod(reduced dims)``; the division is realized as
        a ``right_shift`` custom op, so ``prod`` must be a power of two.
        Large reductions are chunked into slices of at most
        ``MAXIMUM_SIZE`` elements along a flattened tail axis; each chunk
        is summed and shifted immediately, and the partials are combined
        hierarchically, ``MAXIMUM_SIZE`` at a time.

        Raises
        ------
        NotImplementedError
            If the operator uses ``exclude=True``.
        AssertionError
            If the reduced element count is not a power of two.
        """
        name = op.attr('name')
        attr, childs = op.list_attr(), sym_iter(op.get_children())

        # NOTE(review): eval on operator attribute strings — attrs come
        # from the serialized graph; ast.literal_eval would be safer if
        # untrusted models are ever loaded.
        axis = eval(attr['axis'])
        if isinstance(axis, int):
            axis = (axis,)
        else:
            assert isinstance(axis, tuple), (axis, type(axis))

        keepdims = eval(attr.get('keepdims', 'False'))

        exclude = eval(attr.get('exclude', 'False'))
        if exclude:
            raise NotImplementedError

        op = childs[0]
        shp = infer_shapes[op.attr('name')][get_entry_id(op)]
        prod = int(nd.prod(nd.array([shp[ax] for ax in axis])).asscalar())
        power_value = int(math.log2(prod))
        assert 1 << power_value == prod, \
            "unsupported operator mean, axis: {}, shp: {}".format(axis, shp)

        MAXIMUM_SIZE = 128
        assert MAXIMUM_SIZE > 1
        if prod <= MAXIMUM_SIZE:
            # Small reduction: a single sum followed by one shift.
            op = mx.sym.sum(
                op, axis=axis, keepdims=keepdims, name=N.n('sum'))
            if prod == 1:
                # Dividing by 1: the sum alone already is the mean.
                return op
            op = mx.sym.Custom(
                op, shift_bit=int(math.log2(prod)), name=name,
                op_type='right_shift')
            return op

        # TODO(ryt): select batch_axis
        # Move all reduced axes to the tail and flatten them into one.
        axis_set = set(axis)
        oaxes = list(range(len(shp)))
        naxes = sorted(axis)
        axes = [ax for ax in oaxes if ax not in axis_set] + naxes
        nshp = tuple(
            [sz for ax, sz in enumerate(shp) if ax not in axis_set]) + \
            (prod,)
        transposed = False
        if axes != oaxes:
            transposed = True
            op = mx.sym.transpose(op, axes=axes, name=N.n('transpose'))
        reshaped = False
        if len(naxes) > 1:
            reshaped = True
            op = mx.sym.reshape(op, shape=nshp, name=N.n('reshape'))

        # Slice the flattened tail axis into MAXIMUM_SIZE chunks, then
        # sum-and-shift each chunk.
        ops = []
        eax = len(nshp) - 1
        for i in range(0, prod, MAXIMUM_SIZE):
            cop = mx.sym.slice_axis(
                op, axis=eax, begin=i, end=i + MAXIMUM_SIZE,
                name=N.n('slice_axis'))
            ops.append(cop)
        sb = int(math.log2(MAXIMUM_SIZE))
        ops = sum_and_rightshift(ops, eax, sb)
        # Combine partials hierarchically, MAXIMUM_SIZE at a time.
        while len(ops) > MAXIMUM_SIZE:
            # BUGFIX: was `MAXIMUM_SIZE % len(ops) == 0`, which is always
            # false when len(ops) > MAXIMUM_SIZE (the loop guard).  The
            # partial count is a power of two here, so it divides evenly.
            assert len(ops) % MAXIMUM_SIZE == 0
            nops = []
            for i in range(0, len(ops), MAXIMUM_SIZE):
                cop = mx.sym.concat(
                    *ops[i:i + MAXIMUM_SIZE], dim=eax, name=N.n('concat'))
                nops.append(cop)
            ops = sum_and_rightshift(nops, eax, sb)
        res_sz = len(ops)
        assert res_sz > 1
        sb = int(math.log2(res_sz))
        op = mx.sym.concat(*ops, dim=eax, name=N.n('concat'))
        if keepdims:
            op = mx.sym.sum(op, axis=eax, keepdims=True, name=N.n('sum'))
            if reshaped:
                # Restore the flattened reduced axes as size-1 dims.
                for i in range(1, len(naxes)):
                    op = mx.sym.expand_dims(
                        op, axis=i + len(nshp) - 1, name=N.n('expand_dims'))
            if transposed:
                # Invert the earlier permutation.
                raxes = [0] * len(axes)
                for i, ax in enumerate(axes):
                    raxes[ax] = i
                # BUGFIX: was `mx.sym.tranpose` (typo), which would raise
                # AttributeError whenever this path executed.
                op = mx.sym.transpose(op, axes=raxes, name=N.n('transpose'))
        else:
            op = mx.sym.sum(op, axis=eax, name=N.n('sum'))
        op = mx.sym.Custom(
            op, shift_bit=sb, name=name, op_type='right_shift')
        return op
3218
+
3219
def fuse_transpose_reduce(op, infer_shapes):
    """Swap a reduce op (``sum`` / ``mean``) with its child ``Transpose``.

    For ``cX -> Transpose(axes) -> reduce(dims1)`` the reduction is moved
    onto the transpose's input, over ``dims2 = [axes[i] for i in dims1]``:

    - ``keepdims=True``: reduce first, then replay the original transpose
      on top.  This lets a later Transpose visitor fuse again in chains
      like ``Transpose -> reduce -> Transpose``.
    - ``keepdims=False``: the residual permutation of the surviving axes
      is computed; if it is the identity the transpose is dropped
      entirely, otherwise a smaller transpose is re-emitted above the
      reduce.

    Ops other than sum/mean, or without a Transpose child, are returned
    unchanged.
    """
    name, op_name = op.attr('name'), op.attr('op_name')
    shp = infer_shapes[name][get_entry_id(op)]
    if op_name not in [Sum.op_name, Mean.op_name]:
        return op
    attr = op.list_attr()
    X = op.get_children()[0]
    if X.attr('op_name') != Transpose.op_name:
        return op

    xshp = infer_shapes[X.attr('name')][get_entry_id(X)]
    # Reduce dims expressed w.r.t. the transpose output (default: all).
    dims = get_attr(attr, 'axis', list(range(len(xshp))))
    perm = get_attr(X.list_attr(), 'axes')
    cX = X.get_children()[0]
    # Same reduce dims, re-expressed w.r.t. the transpose *input*.
    mapped = [perm[i] for i in dims]
    mapped_sorted = sorted(mapped)
    make_reduce = get_mxnet_op(op_name)

    if get_attr(attr, 'keepdims', False):
        # Reduce below, then replay the original permutation on top.
        out = make_reduce(
            cX, axis=mapped_sorted, keepdims=True, name=N.n("reduce"))
        return mx.sym.transpose(out, axes=perm, name=name)

    # Axes that survive the reduction, renumbered into the reduced rank.
    dropped = set(mapped)
    kept = [ax for ax in perm if ax not in dropped]
    new_rank = {ax: i for i, ax in enumerate(sorted(kept))}
    out_perm = [new_rank[ax] for ax in kept]
    if out_perm != list(range(len(shp))):
        out = make_reduce(
            cX, axis=mapped_sorted, keepdims=False, name=N.n("reduce"))
        return mx.sym.transpose(out, axes=out_perm, name=name)
    # Residual permutation is the identity: drop the transpose entirely.
    return make_reduce(cX, axis=mapped_sorted, keepdims=False, name=name)