
Commit ca6a688

add list_model.py; quantization configs for resnet18_v1, resnet34_v1, resnet34_v2, resnet101_v1, resnet152_v1, resnet152_v2, resnet18_v1b, resnet34_v1b, resnet50_v1b, resnet101_v1b, etc.; tfm_ops Mean rewrite/fuse_transpose and corresponding unit tests
1 parent abc992a commit ca6a688

File tree

83 files changed: +1348 −30 lines

python/mrt/cvm_op.py

Lines changed: 39 additions & 0 deletions

@@ -157,6 +157,24 @@ def forward(self, is_train, req, in_data, out_data, aux):
     def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
         assert False
 
+class RightShiftV2(mx.operator.CustomOp):
+    def __init__(self, shift_bit, **kwargs):
+        super(RightShiftV2, self).__init__(**kwargs)
+        self.sb = int(shift_bit)
+        assert self.sb > 0
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        assert is_train == False
+        X = in_data[0]
+        out = X.round()
+        if self.sb > 1:
+            out = out / 2**self.sb
+            out = out.round()
+        self.assign(out_data[0], req[0], out)
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        assert False
+
 class Annotate(mx.operator.CustomOp):
     def __init__(self, in_prec, out_prec, anno_type):
         super(Annotate, self).__init__()
@@ -270,6 +288,27 @@ def infer_type(self, in_type):
     def create_operator(self, ctx, shapes, dtypes):
         return RightShift(self.precision, self.shift_bit)
 
+
+@mx.operator.register("right_shift")
+class RightShiftV2Prop(mx.operator.CustomOpProp):
+    """ MxNet right_shift operator property class.
+    """
+    def __init__(self, shift_bit=0):
+        self.shift_bit = shift_bit
+        super(RightShiftV2Prop, self).__init__(need_top_grad=False)
+    def list_arguments(self):
+        return ['data']
+    def list_outputs(self):
+        return ['output']
+    def infer_shape(self, in_shape):
+        X_shape = in_shape[0]
+        out_shape = in_shape[0]
+        return [X_shape], [out_shape], []
+    def infer_type(self, in_type):
+        X_type = in_type[0]
+        return [X_type], [X_type], []
+    def create_operator(self, ctx, shapes, dtypes):
+        return RightShiftV2(self.shift_bit)
 @mx.operator.register("cvm_lut")
 class LUTProp(mx.operator.CustomOpProp):
     """ MxNet cvm_lut operator property class.

python/mrt/tfm_ops.py

Lines changed: 220 additions & 26 deletions

@@ -1449,41 +1449,42 @@ def validate(self, op, **kwargs):
         return op
 
     def fuse_transpose(self, op, **kwargs):
-        """ Customized fuse_transpose pass Introduction.
+        return fuse_transpose_reduce(op, kwargs["infer_shapes"])
+        # """ Customized fuse_transpose pass Introduction.
 
-        Suppose 'keepdims' is True and the input is 'Transpose'.
+        # Suppose 'keepdims' is True and the input is 'Transpose'.
 
-        .. code-block:: none
+        # .. code-block:: none
 
-            cX
-            |
-            Transpose(axis)
-            |
-            sum(dims1)
+        #     cX
+        #     |
+        #     Transpose(axis)
+        #     |
+        #     sum(dims1)
 
-        then, the graph can be transformed into:
+        # then, the graph can be transformed into:
 
-        .. code-block:: none
+        # .. code-block:: none
 
-            cX
-            |
-            Sum(dims2)
+        #     cX
+        #     |
+        #     Sum(dims2)
 
-        where:
+        # where:
 
-        .. code-block:: python
+        # .. code-block:: python
 
-            dims2 = [axis[i] for i in dims1]
-        """
-        name, attr, X = op.attr('name'), op.list_attr(), op.get_children()[0]
-        xshp = kwargs['infer_shapes'][X.attr('name')][get_entry_id(X)]
-        axis = get_attr(attr, 'axis', [i for i in range(len(xshp))])
-        keepdims = get_attr(attr, 'keepdims', False)
-        if X.attr('op_name') == Transpose.op_name and not keepdims:
-            axes, op = get_attr(X.list_attr(), 'axes'), X.get_children()[0]
-            axis = [axes[i] for i in axis]
-            op = mx.sym.sum(op, axis=axis, keepdims=keepdims, name=name)
-        return op
+        #     dims2 = [axis[i] for i in dims1]
+        # """
+        # name, attr, X = op.attr('name'), op.list_attr(), op.get_children()[0]
+        # xshp = kwargs['infer_shapes'][X.attr('name')][get_entry_id(X)]
+        # axis = get_attr(attr, 'axis', [i for i in range(len(xshp))])
+        # keepdims = get_attr(attr, 'keepdims', False)
+        # if X.attr('op_name') == Transpose.op_name and not keepdims:
+        #     axes, op = get_attr(X.list_attr(), 'axes'), X.get_children()[0]
+        #     axis = [axes[i] for i in axis]
+        #     op = mx.sym.sum(op, axis=axis, keepdims=keepdims, name=name)
+        # return op
 
     def calculate_ops(self, op, **kwargs):
         infer_shapes = kwargs['infer_shapes']
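The rewritten pass above delegates to fuse_transpose_reduce, defined later in this file. A plain numpy check of the underlying identity (illustrative only, not MRT code); note that fuse_transpose_reduce additionally inserts a transpose when the surviving axes come out permuted, which the commented-out pass did not do:

    import numpy as np

    x = np.random.randn(2, 3, 4, 5)
    axes = (2, 0, 3, 1)                    # Transpose(axes)
    dims1 = (0, 2)                         # reduce axes applied after it
    dims2 = tuple(axes[i] for i in dims1)  # -> (2, 3)

    # summing after the transpose equals summing the original over dims2
    assert np.allclose(x.transpose(axes).sum(axis=dims1),
                       x.sum(axis=dims2))

In this example the surviving axes already come out in sorted order, so no extra transpose is needed.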
@@ -3103,3 +3104,196 @@ def fusable_cvm_precision_attr(op):
     assert is_fusable_cvm_precision(op)
     attr = op.list_attr()
     return get_attr(attr, 'precision'), get_attr(attr, 'shift_bit', 0)
+
+def sum_and_rightshift(ops, axis, shift_bit):
+    nops = []
+    while ops:
+        cop = ops.pop()
+        cop = mx.sym.sum(
+            cop, axis=axis, keepdims=True, name=N.n('sum'))
+        cop = mx.sym.Custom(
+            cop, shift_bit=shift_bit, name=N.n('custom'),
+            op_type='right_shift')
+        nops.append(cop)
+    ops = nops
+    return ops
+
+
+@register_transformer("mean")
+class Mean(Transformer):
+    def fuse_transpose(self, op, **kwargs):
+        return fuse_transpose_reduce(op, kwargs["infer_shapes"])
+
+    def rewrite(self, op, **kwargs):
+        return self._decompose_axis(op, kwargs['infer_shapes'])
+
+    def _decompose_axis(self, op, infer_shapes):
+        name = op.attr('name')
+        attr, childs = op.list_attr(), sym_iter(op.get_children())
+
+        axis = eval(attr['axis'])
+        if isinstance(axis, int):
+            axis = (axis,)
+        else:
+            assert isinstance(axis, tuple), (axis, type(axis))
+
+        keepdims = eval(attr.get('keepdims', 'False'))
+
+        exclude = eval(attr.get('exclude', 'False'))
+        if exclude:
+            raise NotImplementedError
+
+        op = childs[0]
+        shp = infer_shapes[op.attr('name')][get_entry_id(op)]
+        prod = int(nd.prod(nd.array([shp[ax] for ax in axis])).asscalar())
+        power_value = int(math.log2(prod))
+        assert 1 << power_value == prod, \
+            "unsupported operator mean, axis: {}, shp: {}".format(axis, shp)
+
+        MAXIMUM_SIZE = 128
+        assert MAXIMUM_SIZE > 1
+        if prod <= MAXIMUM_SIZE:
+            op = mx.sym.sum(
+                op, axis=axis, keepdims=keepdims, name=N.n('sum'))
+            if prod == 1:
+                return op
+            op = mx.sym.Custom(
+                op, shift_bit=int(math.log2(prod)), name=name,
+                op_type='right_shift')
+            return op
+
+        # TODO(ryt): select batch_axis
+        axis_set = set(axis)
+        oaxes = list(range(len(shp)))
+        naxes = sorted(axis)
+        axes = [ax for ax in oaxes if ax not in axis_set] + naxes
+        nshp = tuple(
+            [sz for ax, sz in enumerate(shp) if ax not in axis_set]) + \
+            (prod,)
+        transposed = False
+        if axes != oaxes:
+            transposed = True
+            op = mx.sym.transpose(op, axes=axes, name=N.n('transpose'))
+        reshaped = False
+        if len(naxes) > 1:
+            reshaped = True
+            op = mx.sym.reshape(op, shape=nshp, name=N.n('reshape'))
+        ops = []
+        eax = len(nshp) - 1
+        for i in range(0, prod, MAXIMUM_SIZE):
+            cop = mx.sym.slice_axis(
+                op, axis=eax, begin=i, end=i+MAXIMUM_SIZE,
+                name=N.n('slice_axis'))
+            ops.append(cop)
+        sb = int(math.log2(MAXIMUM_SIZE))
+        ops = sum_and_rightshift(ops, eax, sb)
+        while len(ops) > MAXIMUM_SIZE:
+            assert len(ops) % MAXIMUM_SIZE == 0
+            nops = []
+            for i in range(0, len(ops), MAXIMUM_SIZE):
+                cop = mx.sym.concat(
+                    *ops[i:i+MAXIMUM_SIZE], dim=eax, name=N.n('concat'))
+                nops.append(cop)
+            ops = sum_and_rightshift(nops, eax, sb)
+        res_sz = len(ops)
+        assert res_sz > 1
+        sb = int(math.log2(res_sz))
+        op = mx.sym.concat(*ops, dim=eax, name=N.n('concat'))
+        if keepdims:
+            op = mx.sym.sum(op, axis=eax, keepdims=True, name=N.n('sum'))
+            if reshaped:
+                for i in range(1, len(naxes)):
+                    op = mx.sym.expand_dims(
+                        op, axis=i+len(nshp)-1, name=N.n('expand_dims'))
+            if transposed:
+                raxes = [0] * len(axes)
+                for i, ax in enumerate(axes):
+                    raxes[ax] = i
+                op = mx.sym.transpose(op, axes=raxes, name=N.n('transpose'))
+        else:
+            op = mx.sym.sum(op, axis=eax, name=N.n('sum'))
+        op = mx.sym.Custom(
+            op, shift_bit=sb, name=name, op_type='right_shift')
+        return op
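A worked numeric sketch of the decomposition for prod > MAXIMUM_SIZE (plain numpy, values illustrative): a mean over 512 elements becomes four 128-wide partial sums, each right-shifted by 7 bits (/128), followed by a sum of the four partials right-shifted by 2 bits (/4), since 128 * 4 == 512:

    import numpy as np

    x = np.arange(512, dtype=np.float64)
    # four slice_axis -> sum -> right_shift(7) stages
    partials = [np.round(x[i:i+128].sum() / 2**7)
                for i in range(0, 512, 128)]
    # concat -> sum -> right_shift(2), with sb = log2(len(partials))
    approx = np.round(np.sum(partials) / 2**2)
    print(approx, x.mean())  # 256.0 vs 255.5 -- exact up to rounding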
+
+def fuse_transpose_reduce(op, infer_shapes):
+    """ fuse_transpose pass for reduce ops whose only child is a Transpose.
+        Currently supports `sum` and `mean`.
+
+        .. code-block:: none
+
+            cX
+            |
+            Transpose(axis)
+            |
+            op(dims1)
+
+        then, the graph can be transformed into:
+
+        .. code-block:: none
+
+            cX
+            |
+            op(dims2)
+
+        where:
+
+        .. code-block:: python
+
+            dims2 = [axis[i] for i in dims1]
+
+        If keepdims is true, we swap the order of the reduce op and the
+        transpose in an equivalent way, which can enable further fusion for
+        cases like:
+
+        .. code-block:: none
+
+            cX
+            |
+            Transpose1
+            |
+            reduce
+            |
+            Transpose2
+
+        which, when Transpose2 is visited, can be further fused into:
+
+        .. code-block:: none
+
+            cX
+            |
+            reduce
+            |
+            Transpose
+    """
+    name, op_name = op.attr('name'), op.attr('op_name')
+    shp = infer_shapes[name][get_entry_id(op)]
+    if op_name not in [Sum.op_name, Mean.op_name]:
+        return op
+    attr, X = op.list_attr(), op.get_children()[0]
+    xopn = X.attr('op_name')
+    if xopn != Transpose.op_name:
+        return op
+    xshp = infer_shapes[X.attr('name')][get_entry_id(X)]
+    axis = get_attr(attr, 'axis', [i for i in range(len(xshp))])
+    axes, cX = get_attr(X.list_attr(), 'axes'), X.get_children()[0]
+    naxis = [axes[i] for i in axis]
+    naxis_sorted = sorted(naxis)
+    keepdims = get_attr(attr, 'keepdims', False)
+    if keepdims:
+        op = get_mxnet_op(op_name)(
+            cX, axis=naxis_sorted, keepdims=True, name=N.n("reduce"))
+        op = mx.sym.transpose(op, axes=axes, name=name)
+    else:
+        naxis_set = set(naxis)
+        naxes = [ax for ax in axes if ax not in naxis_set]
+        naxes_dict = {ax: i for i, ax in enumerate(sorted(naxes))}
+        naxes = [naxes_dict[ax] for ax in naxes]
+        axes_ref = [i for i in range(len(shp))]
+        if naxes != axes_ref:
+            op = get_mxnet_op(op_name)(
+                cX, axis=naxis_sorted, keepdims=False, name=N.n("reduce"))
+            op = mx.sym.transpose(op, axes=naxes, name=name)
+        else:
+            op = get_mxnet_op(op_name)(
+                cX, axis=naxis_sorted, keepdims=False, name=name)
+    return op
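A plain numpy check (illustrative only, not MRT code) of the keepdims branch above: reducing after a transpose equals reducing the original input over naxis = [axes[i] for i in dims1] and then applying the same transpose, because the kept size-1 dims travel with the permutation:

    import numpy as np

    x = np.random.randn(2, 3, 4, 5)
    axes = (2, 0, 3, 1)
    dims1 = (0, 2)
    naxis = tuple(sorted(axes[i] for i in dims1))  # -> (2, 3)

    lhs = x.transpose(axes).mean(axis=dims1, keepdims=True)
    rhs = x.mean(axis=naxis, keepdims=True).transpose(axes)
    assert np.allclose(lhs, rhs)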
New file (densenet121 quantization config)

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+COMMON:
+  MODEL_NAME: densenet121
+CALIBRATE:
+  NUM_CALIB: 1
+  DATASET_NAME: imagenet
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+QUANTIZE:
+  INPUT_PRECISION: 8
+  OUTPUT_PRECISION: 8
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+EVALUATE:
+  BATCH: 16
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+  ITER_NUM: 10000
New file (densenet169 quantization config)

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+COMMON:
+  MODEL_NAME: densenet169
+CALIBRATE:
+  NUM_CALIB: 1
+  DATASET_NAME: imagenet
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+QUANTIZE:
+  INPUT_PRECISION: 8
+  OUTPUT_PRECISION: 8
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+EVALUATE:
+  BATCH: 16
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+  ITER_NUM: 10000
New file (densenet201 quantization config)

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+COMMON:
+  MODEL_NAME: densenet201
+CALIBRATE:
+  NUM_CALIB: 1
+  DATASET_NAME: imagenet
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+QUANTIZE:
+  INPUT_PRECISION: 8
+  OUTPUT_PRECISION: 8
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+EVALUATE:
+  BATCH: 16
+  DEVICE_TYPE: gpu
+  DEVICE_IDS: [0]
+  ITER_NUM: 1000
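All three DenseNet configs share the same layout. A minimal sketch of reading one with PyYAML (illustrative only: the file name below is an assumption, since the new files' paths are not shown above, and MRT's own config loader may differ):

    import yaml  # PyYAML

    # 'densenet121.yaml' stands in for the first config above (hypothetical path)
    with open('densenet121.yaml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['COMMON']['MODEL_NAME'])         # densenet121
    print(cfg['CALIBRATE']['DATASET_NAME'])    # imagenet
    print(cfg['QUANTIZE']['INPUT_PRECISION'])  # 8
    print(cfg['EVALUATE']['BATCH'])            # 16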

tests/mrt/model_zoo/mnist.yaml renamed to tests/mrt/model_zoo/imagenet/mobilenet0.5.yaml

Lines changed: 3 additions & 4 deletions

@@ -1,10 +1,9 @@
 COMMON:
-  MODEL_NAME: mnist_dapp
-PREPARE:
-  INPUT_SHAPE: [-1,1,28,28]
+  MODEL_NAME: mobilenet0.5
 CALIBRATE:
   NUM_CALIB: 1
-  DATASET_NAME: mnist
+  DATASET_NAME: imagenet
+  LAMBD: 10
   DEVICE_TYPE: gpu
   DEVICE_IDS: [0]
 QUANTIZE:
