
Commit abc992a

[doc] tfm_ops FullyConnected.reduce ElemwiseMul.rewrite Activation.rewrite
1 parent: 4afe38a

File tree

2 files changed: +35 −7 lines

docs/mrt/api/operator.rst

Lines changed: 6 additions & 3 deletions
@@ -91,6 +91,7 @@ MxNet Supported Operators are listed as below:
 
 + :py:class:`ElemwiseAdd <mrt.tfm_ops.ElemwiseAdd>`
 + :py:class:`ElemwiseSub <mrt.tfm_ops.ElemwiseSub>`
++ :py:class:`ElemwiseMul <mrt.tfm_ops.ElemwiseMul>`
 + :py:class:`Clip <mrt.tfm_ops.Clip>`
 + :py:class:`negative <mrt.tfm_ops.Negative>`
 + :py:class:`abs <mrt.tfm_ops.Abs>`
@@ -137,7 +138,7 @@ MxNet Supported Operators are listed as below:
     :members: rewrite
 
 .. autoclass:: mrt.tfm_ops.Activation
-    :members: validate
+    :members: validate, rewrite
 
 .. autoclass:: mrt.tfm_ops.Convolution
     :members: rewrite, quantize
@@ -170,7 +171,7 @@ MxNet Supported Operators are listed as below:
     :members:
 
 .. autoclass:: mrt.tfm_ops.FullyConnected
-    :members: rewrite, quantize
+    :members: rewrite, reduce, quantize
 
 .. autoclass:: mrt.tfm_ops.Sigmoid
     :members: quantize
@@ -262,10 +263,12 @@ MxNet Supported Operators are listed as below:
 .. autoclass:: mrt.tfm_ops.ElemwiseAdd
     :members: fuse_transpose, quantize
 
-
 .. autoclass:: mrt.tfm_ops.ElemwiseSub
     :members: fuse_transpose, quantize
 
+.. autoclass:: mrt.tfm_ops.ElemwiseMul
+    :members: rewrite
+
 .. autoclass:: mrt.tfm_ops.Dropout
     :members: fuse_transpose
 

python/mrt/tfm_ops.py

Lines changed: 29 additions & 4 deletions
@@ -234,6 +234,11 @@ def fuse_transpose(self, op, **kwargs):
         return op
 
     def rewrite(self, op, **kwargs):
+        """ Equivalent transformation of the Activation operator.
+        Only applies when the attribute act_type equals relu or
+        sigmoid, in which case the activation can be rewritten
+        directly into the corresponding operator.
+        """
         attr = op.list_attr()
         if attr['act_type'] == Relu.op_name:
             op = Relu().rewrite(op, **kwargs)
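
For reference, a minimal standalone sketch of the transform this docstring describes; the helper rewrite_activation is hypothetical (MRT implements this as a method dispatching to Relu().rewrite and Sigmoid().rewrite, per the diff), and the mx.sym calls assume plain MXNet symbols:

import mxnet as mx

def rewrite_activation(op):
    # Hypothetical free-function version of Activation.rewrite:
    # an Activation whose act_type is relu or sigmoid is replaced
    # by the equivalent dedicated operator.
    attr = op.list_attr()
    data = op.get_children()[0]
    if attr['act_type'] == 'relu':
        return mx.sym.relu(data, name=op.attr('name'))
    if attr['act_type'] == 'sigmoid':
        return mx.sym.sigmoid(data, name=op.attr('name'))
    return op  # other act_types are left unchanged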
@@ -671,7 +676,26 @@ def rewrite(self, op, **kwargs):
         return op
 
     def reduce(self, op, **kwargs):
-        # TODO(ryt.dev) documentation
+        """ Dimension reduction function covering
+        both flatten cases.
+
+        Denote the input as X and the transformed operator as Y.
+        If flatten is true, a single reduction of the high-dimensional
+        input to 2 dimensions is needed:
+
+        .. math::
+            RX = reshape(X)
+            Y = FullyConnected(RX)
+
+        If flatten is false, the input is first reduced to 2
+        dimensions; after the FullyConnected op, the output is
+        reshaped back to the correct output shape:
+
+        .. math::
+            RX = reshape(X)
+            out = FullyConnected(RX)
+            Y = reshape(out)
+        """
         name = op.attr('name')
         attr, childs = op.list_attr(), sym_iter(op.get_children())
         cns = [c.attr('name') for c in childs]
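
The flatten=false branch above can be sketched with plain MXNet symbols; reduce_fc and its xshp argument are illustrative only (the real method reads shapes from kwargs['infer_shapes']):

import mxnet as mx

def reduce_fc(X, weight, bias, num_hidden, xshp):
    # Collapse the leading axes: (d0, ..., dk, C) -> (d0*...*dk, C).
    RX = mx.sym.reshape(X, shape=(-1, xshp[-1]))
    out = mx.sym.FullyConnected(RX, weight=weight, bias=bias,
                                num_hidden=num_hidden)
    # Restore the leading axes:
    # (d0*...*dk, num_hidden) -> (d0, ..., dk, num_hidden).
    return mx.sym.reshape(out, shape=tuple(xshp[:-1]) + (num_hidden,))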
@@ -1979,12 +2003,13 @@ def fuse_transpose(self, op, **kwargs):
         return _ft_multi_input(op)
 
     def rewrite(self, op, **kwargs):
+        """ Validate that the inferred shapes of lhs and rhs are the
+        same, so that this op can be rewritten into broadcast_mul;
+        the corresponding CVM op is then optimized at compile time.
+        """
         name, op_name = op.attr('name'), op.attr('op_name')
         childs = sym_iter(op.get_children())
 
-        # validate the infer_shapes of lhs and rhs must be the same
-        # thus this op could be rewrite into broadcast_mul
-        # corresponding cvm op would be optimized at compile time
         ln, rn = [c.attr('name') for c in childs]
         infer_shapes = kwargs['infer_shapes']
         lshp, rshp = infer_shapes[ln], infer_shapes[rn]
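
A sketch of this rewrite as a free function; rewrite_elemwise_mul and its infer_shapes argument mirror the names in the diff but are not MRT's actual API:

import mxnet as mx

def rewrite_elemwise_mul(op, infer_shapes):
    # elemwise_mul requires identical lhs/rhs shapes, so once the
    # shapes are validated the op can become broadcast_mul (no
    # broadcasting actually occurs); the corresponding CVM op is
    # then optimized at compile time.
    childs = op.get_children()
    lhs, rhs = childs[0], childs[1]
    lshp = infer_shapes[lhs.attr('name')]
    rshp = infer_shapes[rhs.attr('name')]
    assert lshp == rshp, "lhs and rhs shapes must match"
    return mx.sym.broadcast_mul(lhs, rhs, name=op.attr('name'))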
