@@ -234,6 +234,11 @@ def fuse_transpose(self, op, **kwargs):
234
234
return op
235
235
236
236
def rewrite (self , op , ** kwargs ):
237
+ """ Equivalent transform of rewrite operator
238
+ Only applies when the attribute act_type equals to relu or sigmoid,
239
+ which indicates that rewrite could be directly tranformed into
240
+ the corresponding operator.
241
+ """
237
242
attr = op .list_attr ()
238
243
if attr ['act_type' ] == Relu .op_name :
239
244
op = Relu ().rewrite (op , ** kwargs )
@@ -671,7 +676,26 @@ def rewrite(self, op, **kwargs):
671
676
return op
672
677
673
678
def reduce (self , op , ** kwargs ):
674
- # TODO(ryt.dev) documentation
679
+ """ Dimension reduction function considering
680
+ both flatten cases.
681
+
682
+ Denote the input as X and transformed operator as Y.
683
+ If flatten is true, only one reduction of the high dimension input
684
+ to 2 dimension is needed.
685
+
686
+ .. math::
687
+ RX = reshape(X)
688
+ Y = FullyConnected(RX)
689
+
690
+ If flatten is false, firstly one reduction of the input to 2
691
+ dimension is needed. After FullyConnected op, the ouput should
692
+ be reshaped to the correct output shape.
693
+
694
+ .. math::
695
+ RX = reshape(X)
696
+ out = FullyConnected(RX)
697
+ Y = reshape(out)
698
+ """
675
699
name = op .attr ('name' )
676
700
attr , childs = op .list_attr (), sym_iter (op .get_children ())
677
701
cns = [c .attr ('name' ) for c in childs ]
@@ -1979,12 +2003,13 @@ def fuse_transpose(self, op, **kwargs):
1979
2003
return _ft_multi_input (op )
1980
2004
1981
2005
def rewrite (self , op , ** kwargs ):
2006
+ """ validate the infer_shapes of lhs and rhs must be the same
2007
+ thus this op could be rewrite into broadcast_mul
2008
+ corresponding cvm op would be optimized at compile time
2009
+ """
1982
2010
name , op_name = op .attr ('name' ), op .attr ('op_name' )
1983
2011
childs = sym_iter (op .get_children ())
1984
2012
1985
- # validate the infer_shapes of lhs and rhs must be the same
1986
- # thus this op could be rewrite into broadcast_mul
1987
- # corresponding cvm op would be optimized at compile time
1988
2013
ln , rn = [c .attr ('name' ) for c in childs ]
1989
2014
infer_shapes = kwargs ['infer_shapes' ]
1990
2015
lshp , rshp = infer_shapes [ln ], infer_shapes [rn ]
0 commit comments